1use crate::{Error, Result};
7use axum::http::{HeaderMap, Method, Uri};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::debug;
14
/// Tracked state for a single resource.
#[derive(Debug, Clone)]
struct StateInstance {
    /// Identifier (extracted from a request) that this state belongs to.
    resource_id: String,
    /// Name of the state the resource is currently in.
    current_state: String,
    /// Resource type recorded when the instance was created.
    resource_type: String,
    /// Extra values exposed to body templates as `{{state_data.<key>}}`.
    state_data: HashMap<String, Value>,
}
27
28impl StateInstance {
29 fn new(resource_id: String, resource_type: String, initial_state: String) -> Self {
30 Self {
31 resource_id,
32 current_state: initial_state,
33 resource_type,
34 state_data: HashMap::new(),
35 }
36 }
37
38 fn transition_to(&mut self, new_state: String) {
39 self.current_state = new_state;
40 }
41}
42
/// In-memory store of state instances keyed by resource ID.
struct StateMachineManager {
    /// Instances keyed by resource ID, shared behind a tokio async `RwLock`.
    instances: Arc<RwLock<HashMap<String, StateInstance>>>,
}
48
49impl StateMachineManager {
50 fn new() -> Self {
51 Self {
52 instances: Arc::new(RwLock::new(HashMap::new())),
53 }
54 }
55
56 async fn get_or_create_instance(
57 &self,
58 resource_id: String,
59 resource_type: String,
60 initial_state: String,
61 ) -> Result<StateInstance> {
62 let mut instances = self.instances.write().await;
63 if let Some(instance) = instances.get(&resource_id) {
64 Ok(instance.clone())
65 } else {
66 let instance = StateInstance::new(resource_id.clone(), resource_type, initial_state);
67 instances.insert(resource_id, instance.clone());
68 Ok(instance)
69 }
70 }
71
72 async fn update_instance(&self, resource_id: String, instance: StateInstance) -> Result<()> {
73 let mut instances = self.instances.write().await;
74 instances.insert(resource_id, instance);
75 Ok(())
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct StatefulConfig {
82 pub resource_id_extract: ResourceIdExtract,
84 pub resource_type: String,
86 pub state_responses: HashMap<String, StateResponse>,
88 pub transitions: Vec<TransitionTrigger>,
90}
91
/// Strategy for deriving a resource ID from an incoming request.
///
/// Serialized as an internally tagged object whose `type` is the snake_case variant name,
/// e.g. `{"type": "header", "name": "X-Order-Id"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResourceIdExtract {
    /// Take the ID from the request path.
    ///
    /// NOTE(review): `StatefulResponseHandler::extract_resource_id` currently returns the last
    /// non-empty path segment; `param` is only used in its error message.
    PathParam {
        param: String,
    },
    /// Take the ID from the JSON request body using a dotted path such as `$.order.id`
    /// (numeric segments index into arrays). The value must be a string or number.
    JsonPath {
        path: String,
    },
    /// Take the ID from the named request header.
    Header {
        name: String,
    },
    /// Take the ID from the named URL query parameter.
    QueryParam {
        param: String,
    },
    /// Try each extractor in order and use the first one that succeeds.
    Composite {
        extractors: Vec<ResourceIdExtract>,
    },
}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct StateResponse {
126 pub status_code: u16,
128 pub headers: HashMap<String, String>,
130 pub body_template: String,
132 pub content_type: String,
134}
135
/// Rule that moves a resource from `from_state` to `to_state` when a matching request arrives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionTrigger {
    /// HTTP method the request must use; (de)serialized as its string form, e.g. `"POST"`.
    #[serde(with = "method_serde")]
    pub method: Method,
    /// Pattern the request path must match (`{name}` matches one segment, `*` any run of
    /// characters).
    pub path_pattern: String,
    /// State the resource must currently be in for this rule to apply.
    pub from_state: String,
    /// State the resource moves to when the rule fires.
    pub to_state: String,
    /// Optional guard. Conditions starting with `$.` are evaluated as a JSON path against the
    /// request body; any other string does not block the transition.
    pub condition: Option<String>,
}
151
152mod method_serde {
153 use axum::http::Method;
154 use serde::{Deserialize, Deserializer, Serialize, Serializer};
155
156 pub fn serialize<S>(method: &Method, serializer: S) -> Result<S::Ok, S::Error>
157 where
158 S: Serializer,
159 {
160 method.as_str().serialize(serializer)
161 }
162
163 pub fn deserialize<'de, D>(deserializer: D) -> Result<Method, D::Error>
164 where
165 D: Deserializer<'de>,
166 {
167 let s = String::deserialize(deserializer)?;
168 Method::from_bytes(s.as_bytes()).map_err(serde::de::Error::custom)
169 }
170}
171
/// Serves responses whose content depends on a per-resource state machine.
pub struct StatefulResponseHandler {
    /// Shared store of per-resource state instances.
    state_manager: Arc<StateMachineManager>,
    /// Registered configurations keyed by path pattern.
    configs: Arc<RwLock<HashMap<String, StatefulConfig>>>,
}
179
180impl StatefulResponseHandler {
181 pub fn new() -> Result<Self> {
183 Ok(Self {
184 state_manager: Arc::new(StateMachineManager::new()),
185 configs: Arc::new(RwLock::new(HashMap::new())),
186 })
187 }
188
189 pub async fn add_config(&self, path_pattern: String, config: StatefulConfig) {
191 let mut configs = self.configs.write().await;
192 configs.insert(path_pattern, config);
193 }
194
195 pub async fn can_handle(&self, _method: &Method, path: &str) -> bool {
197 let configs = self.configs.read().await;
198 for (pattern, _) in configs.iter() {
199 if self.path_matches(pattern, path) {
200 return true;
201 }
202 }
203 false
204 }
205
206 pub async fn process_request(
208 &self,
209 method: &Method,
210 uri: &Uri,
211 headers: &HeaderMap,
212 body: Option<&[u8]>,
213 ) -> Result<Option<StatefulResponse>> {
214 let path = uri.path();
215
216 let config = {
218 let configs = self.configs.read().await;
219 configs
220 .iter()
221 .find(|(pattern, _)| self.path_matches(pattern, path))
222 .map(|(_, config)| config.clone())
223 };
224
225 let config = match config {
226 Some(c) => c,
227 None => return Ok(None),
228 };
229
230 let resource_id =
232 self.extract_resource_id(&config.resource_id_extract, uri, headers, body)?;
233
234 let state_instance = self
236 .state_manager
237 .get_or_create_instance(
238 resource_id.clone(),
239 config.resource_type.clone(),
240 "initial".to_string(), )
242 .await?;
243
244 let new_state = self
246 .check_transitions(&config, method, path, &state_instance, headers, body)
247 .await?;
248
249 let current_state = if let Some(ref state) = new_state {
251 state.clone()
252 } else {
253 state_instance.current_state.clone()
254 };
255
256 let state_response = config.state_responses.get(¤t_state).ok_or_else(|| {
258 Error::generic(format!("No response configuration for state '{}'", current_state))
259 })?;
260
261 if let Some(ref new_state) = new_state {
263 let mut updated_instance = state_instance.clone();
264 updated_instance.transition_to(new_state.clone());
265 self.state_manager
266 .update_instance(resource_id.clone(), updated_instance)
267 .await?;
268 }
269
270 Ok(Some(StatefulResponse {
271 status_code: state_response.status_code,
272 headers: state_response.headers.clone(),
273 body: self.render_body_template(&state_response.body_template, &state_instance)?,
274 content_type: state_response.content_type.clone(),
275 state: current_state,
276 resource_id: resource_id.clone(),
277 }))
278 }
279
280 fn extract_resource_id(
282 &self,
283 extract: &ResourceIdExtract,
284 uri: &Uri,
285 headers: &HeaderMap,
286 body: Option<&[u8]>,
287 ) -> Result<String> {
288 let path = uri.path();
289 match extract {
290 ResourceIdExtract::PathParam { param } => {
291 let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
294 if let Some(last) = segments.last() {
295 Ok(last.to_string())
296 } else {
297 Err(Error::generic(format!(
298 "Could not extract path parameter '{}' from path '{}'",
299 param, path
300 )))
301 }
302 }
303 ResourceIdExtract::Header { name } => headers
304 .get(name)
305 .and_then(|v| v.to_str().ok())
306 .map(|s| s.to_string())
307 .ok_or_else(|| Error::generic(format!("Header '{}' not found", name))),
308 ResourceIdExtract::QueryParam { param } => {
309 uri.query()
311 .and_then(|q| {
312 url::form_urlencoded::parse(q.as_bytes())
313 .find(|(k, _)| k == param)
314 .map(|(_, v)| v.to_string())
315 })
316 .ok_or_else(|| Error::generic(format!("Query parameter '{}' not found", param)))
317 }
318 ResourceIdExtract::JsonPath { path: json_path } => {
319 let body_str = body
320 .and_then(|b| std::str::from_utf8(b).ok())
321 .ok_or_else(|| Error::generic("Request body is not valid UTF-8".to_string()))?;
322
323 let json: Value = serde_json::from_str(body_str)
324 .map_err(|e| Error::generic(format!("Invalid JSON body: {}", e)))?;
325
326 self.extract_json_path(&json, json_path)
328 }
329 ResourceIdExtract::Composite { extractors } => {
330 for extract in extractors {
332 if let Ok(id) = self.extract_resource_id(extract, uri, headers, body) {
333 return Ok(id);
334 }
335 }
336 Err(Error::generic("Could not extract resource ID from any source".to_string()))
337 }
338 }
339 }
340
341 fn extract_json_path(&self, json: &Value, path: &str) -> Result<String> {
343 let path = path.trim_start_matches('$').trim_start_matches('.');
344 let parts: Vec<&str> = path.split('.').collect();
345
346 let mut current = json;
347 for part in parts {
348 match current {
349 Value::Object(map) => {
350 current = map
351 .get(part)
352 .ok_or_else(|| Error::generic(format!("Path '{}' not found", path)))?;
353 }
354 Value::Array(arr) => {
355 let idx: usize = part
356 .parse()
357 .map_err(|_| Error::generic(format!("Invalid array index: {}", part)))?;
358 current = arr.get(idx).ok_or_else(|| {
359 Error::generic(format!("Array index {} out of bounds", idx))
360 })?;
361 }
362 _ => {
363 return Err(Error::generic(format!(
364 "Cannot traverse path '{}' at '{}'",
365 path, part
366 )));
367 }
368 }
369 }
370
371 match current {
372 Value::String(s) => Ok(s.clone()),
373 Value::Number(n) => Ok(n.to_string()),
374 _ => {
375 Err(Error::generic(format!("Path '{}' does not point to a string or number", path)))
376 }
377 }
378 }
379
380 async fn check_transitions(
382 &self,
383 config: &StatefulConfig,
384 method: &Method,
385 path: &str,
386 instance: &StateInstance,
387 headers: &HeaderMap,
388 body: Option<&[u8]>,
389 ) -> Result<Option<String>> {
390 for transition in &config.transitions {
391 if transition.method != *method {
393 continue;
394 }
395
396 if !self.path_matches(&transition.path_pattern, path) {
397 continue;
398 }
399
400 if instance.current_state != transition.from_state {
402 continue;
403 }
404
405 if let Some(ref condition) = transition.condition {
407 if !self.evaluate_condition(condition, headers, body)? {
408 continue;
409 }
410 }
411
412 debug!(
414 "State transition triggered: {} -> {} for resource {}",
415 transition.from_state, transition.to_state, instance.resource_id
416 );
417
418 return Ok(Some(transition.to_state.clone()));
419 }
420
421 Ok(None)
422 }
423
424 fn evaluate_condition(
426 &self,
427 condition: &str,
428 _headers: &HeaderMap,
429 body: Option<&[u8]>,
430 ) -> Result<bool> {
431 if condition.starts_with("$.") {
434 let body_str = body
435 .and_then(|b| std::str::from_utf8(b).ok())
436 .ok_or_else(|| Error::generic("Request body is not valid UTF-8".to_string()))?;
437
438 let json: Value = serde_json::from_str(body_str)
439 .map_err(|e| Error::generic(format!("Invalid JSON body: {}", e)))?;
440
441 let value = self.extract_json_path(&json, condition)?;
443 Ok(!value.is_empty() && value != "false" && value != "0")
444 } else {
445 Ok(true)
447 }
448 }
449
450 fn render_body_template(&self, template: &str, instance: &StateInstance) -> Result<String> {
452 let mut result = template.to_string();
453
454 result = result.replace("{{state}}", &instance.current_state);
456
457 result = result.replace("{{resource_id}}", &instance.resource_id);
459
460 for (key, value) in &instance.state_data {
462 let placeholder = format!("{{{{state_data.{}}}}}", key);
463 let value_str = match value {
464 Value::String(s) => s.clone(),
465 Value::Number(n) => n.to_string(),
466 Value::Bool(b) => b.to_string(),
467 _ => serde_json::to_string(value).unwrap_or_default(),
468 };
469 result = result.replace(&placeholder, &value_str);
470 }
471
472 Ok(result)
473 }
474
475 pub async fn process_stub_state(
485 &self,
486 method: &Method,
487 uri: &Uri,
488 headers: &HeaderMap,
489 body: Option<&[u8]>,
490 resource_type: &str,
491 resource_id_extract: &ResourceIdExtract,
492 initial_state: &str,
493 transitions: Option<&[TransitionTrigger]>,
494 ) -> Result<Option<StateInfo>> {
495 let resource_id = self.extract_resource_id(resource_id_extract, uri, headers, body)?;
497
498 let state_instance = self
500 .state_manager
501 .get_or_create_instance(
502 resource_id.clone(),
503 resource_type.to_string(),
504 initial_state.to_string(),
505 )
506 .await?;
507
508 let new_state = if let Some(transition_list) = transitions {
510 let path = uri.path();
511 let mut transitioned_state = None;
514
515 for transition in transition_list {
516 if transition.method != *method {
518 continue;
519 }
520
521 if !self.path_matches(&transition.path_pattern, path) {
522 continue;
523 }
524
525 if state_instance.current_state != transition.from_state {
527 continue;
528 }
529
530 if let Some(ref condition) = transition.condition {
532 if !self.evaluate_condition(condition, headers, body)? {
533 continue;
534 }
535 }
536
537 debug!(
539 "State transition triggered in stub processing: {} -> {} for resource {}",
540 transition.from_state, transition.to_state, resource_id
541 );
542
543 transitioned_state = Some(transition.to_state.clone());
544 break; }
546
547 transitioned_state
548 } else {
549 None
550 };
551
552 let final_state = if let Some(ref new_state) = new_state {
554 let mut updated_instance = state_instance.clone();
555 updated_instance.transition_to(new_state.clone());
556 self.state_manager
557 .update_instance(resource_id.clone(), updated_instance)
558 .await?;
559 new_state.clone()
560 } else {
561 state_instance.current_state.clone()
562 };
563
564 Ok(Some(StateInfo {
565 resource_id: resource_id.clone(),
566 current_state: final_state,
567 state_data: state_instance.state_data.clone(),
568 }))
569 }
570
571 pub async fn update_resource_state(
573 &self,
574 resource_id: &str,
575 resource_type: &str,
576 new_state: &str,
577 ) -> Result<()> {
578 let mut instances = self.state_manager.instances.write().await;
579 if let Some(instance) = instances.get_mut(resource_id) {
580 if instance.resource_type == resource_type {
581 instance.transition_to(new_state.to_string());
582 return Ok(());
583 }
584 }
585 Err(Error::generic(format!(
586 "Resource '{}' of type '{}' not found",
587 resource_id, resource_type
588 )))
589 }
590
591 pub async fn get_resource_state(
593 &self,
594 resource_id: &str,
595 resource_type: &str,
596 ) -> Result<Option<StateInfo>> {
597 let instances = self.state_manager.instances.read().await;
598 if let Some(instance) = instances.get(resource_id) {
599 if instance.resource_type == resource_type {
600 return Ok(Some(StateInfo {
601 resource_id: resource_id.to_string(),
602 current_state: instance.current_state.clone(),
603 state_data: instance.state_data.clone(),
604 }));
605 }
606 }
607 Ok(None)
608 }
609
610 fn path_matches(&self, pattern: &str, path: &str) -> bool {
612 let pattern_regex = pattern.replace("{", "(?P<").replace("}", ">[^/]+)").replace("*", ".*");
614 let regex = regex::Regex::new(&format!("^{}$", pattern_regex));
615 match regex {
616 Ok(re) => re.is_match(path),
617 Err(_) => pattern == path, }
619 }
620}
621
/// Snapshot of a tracked resource's state, as returned by the handler's query methods.
#[derive(Debug, Clone)]
pub struct StateInfo {
    /// Identifier of the resource.
    pub resource_id: String,
    /// Name of the state the resource is in.
    pub current_state: String,
    /// Extra data attached to the resource's state.
    pub state_data: HashMap<String, Value>,
}
632
/// Fully rendered response produced by `StatefulResponseHandler::process_request`.
#[derive(Debug, Clone)]
pub struct StatefulResponse {
    /// HTTP status code to send.
    pub status_code: u16,
    /// Extra response headers from the matching `StateResponse`.
    pub headers: HashMap<String, String>,
    /// Rendered body text.
    pub body: String,
    /// Value for the response `Content-Type`.
    pub content_type: String,
    /// State whose response configuration was used.
    pub state: String,
    /// Identifier of the resource the request addressed.
    pub resource_id: String,
}
649
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of (pattern, path, expected match) cases for `path_matches`.
    #[test]
    fn test_path_matching() {
        let handler = StatefulResponseHandler::new().expect("handler construction cannot fail");

        let cases = [
            ("/orders/{id}", "/orders/123", true),
            ("/api/*", "/api/users", true),
            ("/orders/{id}", "/orders/123/items", false),
        ];

        for (pattern, path, expected) in cases {
            assert_eq!(
                handler.path_matches(pattern, path),
                expected,
                "pattern {:?} vs path {:?}",
                pattern,
                path
            );
        }
    }
}