// rohas_runtime/python_runtime.rs

1use crate::error::{Result, RuntimeError};
2use crate::handler::{HandlerContext, HandlerResult};
3use pyo3::prelude::*;
4use pyo3::types::{PyDict, PyModule, PyTuple};
5use rohas_codegen::templates;
6use std::path::{Path, PathBuf};
7use std::sync::{Arc, Mutex};
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11#[pyclass]
12struct RohasLogFn {
13    handler_name: String,
14}
15
16#[pymethods]
17impl RohasLogFn {
18    fn __call__(
19        &self,
20        level: String,
21        handler: String,
22        message: String,
23        fields: Bound<'_, PyDict>,
24    ) -> PyResult<()> {
25        let mut field_map = std::collections::HashMap::new();
26        for (key, value) in fields.iter() {
27            let key_str = key.extract::<String>()?;
28            let value_str = format!("{:?}", value);
29            field_map.insert(key_str, value_str);
30        }
31        
32        let span = tracing::span!(
33            tracing::Level::INFO,
34            "handler_log",
35            handler = %handler
36        );
37        let _enter = span.enter();
38        
39        match level.as_str() {
40            "error" => tracing::error!(message = %message, ?field_map),
41            "warn" => tracing::warn!(message = %message, ?field_map),
42            "info" => tracing::info!(message = %message, ?field_map),
43            "debug" => tracing::debug!(message = %message, ?field_map),
44            "trace" => tracing::trace!(message = %message, ?field_map),
45            _ => tracing::info!(message = %message, ?field_map),
46        }
47        
48        Ok(())
49    }
50}
51
/// Executes Python handler files through an embedded interpreter (pyo3).
///
/// Each handler run is dispatched to tokio's blocking thread pool and executed while
/// holding the GIL (see `execute_handler`).
pub struct PythonRuntime {
    /// Map from a name (presumably the module name - verify) to an imported Python module.
    /// NOTE(review): `execute_handler_sync` re-imports the handler module on every call and
    /// does not consult this map - confirm whether it is still needed.
    modules: Arc<RwLock<std::collections::HashMap<String, Py<PyModule>>>>,
    /// Project root set via `set_project_root`; when present and `<root>/src` exists, that
    /// directory is appended to `sys.path` before a handler is imported.
    project_root: Arc<Mutex<Option<PathBuf>>>,
}
56
57impl PythonRuntime {
58    pub fn new() -> Result<Self> {
59        Python::with_gil(|_| {
60            info!("Python runtime initialized");
61        });
62
63        Ok(Self {
64            modules: Arc::new(RwLock::new(std::collections::HashMap::new())),
65            project_root: Arc::new(Mutex::new(None)),
66        })
67    }
68
69    pub fn set_project_root(&mut self, root: PathBuf) {
70        let mut project_root = self.project_root.lock().unwrap();
71        *project_root = Some(root);
72    }
73
74    pub async fn execute_handler(
75        &self,
76        handler_path: &Path,
77        context: HandlerContext,
78    ) -> Result<HandlerResult> {
79        let start = std::time::Instant::now();
80        let handler_path = handler_path.to_path_buf();
81        let handler_name = context.handler_name.clone();
82        let project_root = self.project_root.lock().unwrap().clone();
83
84        debug!("Executing Python handler: {:?}", handler_path);
85
86        let task = tokio::task::spawn_blocking(move || {
87            Python::with_gil(|py| {
88                Self::execute_handler_sync(
89                    py,
90                    &handler_path,
91                    &handler_name,
92                    &context,
93                    project_root.as_ref(),
94                )
95            })
96        });
97
98        let result = tokio::time::timeout(std::time::Duration::from_secs(30), task)
99            .await
100            .map_err(|_| RuntimeError::ExecutionFailed("Handler execution timeout (30s)".into()))?
101            .map_err(|e| RuntimeError::ExecutionFailed(format!("Task join error: {}", e)))??;
102
103        let execution_time_ms = start.elapsed().as_millis() as u64;
104        Ok(HandlerResult {
105            execution_time_ms,
106            ..result
107        })
108    }
109
110    fn execute_handler_sync(
111        py: Python<'_>,
112        handler_path: &Path,
113        handler_name: &str,
114        context: &HandlerContext,
115        project_root: Option<&PathBuf>,
116    ) -> Result<HandlerResult> {
117        let sys = py.import("sys")?;
118        let sys_path = sys.getattr("path")?;
119
120        if let Some(parent) = handler_path.parent() {
121            sys_path.call_method1("insert", (0, parent.to_str().unwrap()))?;
122        }
123
124        if let Some(root) = project_root {
125            let src_path = root.join("src");
126            if src_path.exists() {
127                let src_path_str = src_path.to_str().unwrap();
128                let path_list: Vec<String> = sys_path.extract()?;
129                if !path_list.contains(&src_path_str.to_string()) {
130                    sys_path.call_method1("append", (src_path_str,))?;
131                    debug!("Added to sys.path (appended): {:?}", src_path);
132                }
133            }
134        }
135
136        let module_name = handler_path
137            .file_stem()
138            .and_then(|s| s.to_str())
139            .ok_or_else(|| RuntimeError::ExecutionFailed("Invalid module name".into()))?;
140
141        // Hot-reload support for Python handlers:
142        // - Invalidate import caches
143        let importlib = py.import("importlib")?;
144        let _ = importlib.call_method0("invalidate_caches");
145
146        if let Ok(modules_dict) = sys.getattr("modules") {
147            let _ = modules_dict.del_item(module_name);
148        }
149
150        let module = PyModule::import(py, module_name).map_err(|e| {
151            RuntimeError::ExecutionFailed(format!("Failed to import module: {}", e))
152        })?;
153
154        let is_event_handler = handler_path
155            .parent()
156            .and_then(|p| p.file_name())
157            .and_then(|n| n.to_str())
158            .map(|n| n == "events")
159            .unwrap_or(false);
160
161        let is_websocket_handler = handler_path
162            .parent()
163            .and_then(|p| p.file_name())
164            .and_then(|n| n.to_str())
165            .map(|n| n == "websockets")
166            .unwrap_or(false);
167
168        let is_middleware = handler_path
169            .parent()
170            .and_then(|p| p.file_name())
171            .and_then(|n| n.to_str())
172            .map(|n| n == "middlewares")
173            .unwrap_or(false);
174
175        let function_name = if is_event_handler || is_websocket_handler {
176            let direct_name = handler_name.to_string();
177            let handle_name = format!("handle_{}", handler_name);
178            
179            match module.hasattr(handle_name.as_str()) {
180                Ok(true) => {
181                    debug!("Using function name '{}' for handler '{}'", handle_name, handler_name);
182                    handle_name
183                }
184                _ => {
185                    debug!("Using function name '{}' for handler '{}' (handle_ variant not found)", direct_name, handler_name);
186                    direct_name
187                }
188            }
189        } else if is_middleware {
190            format!("{}_middleware", templates::to_snake_case(handler_name))
191        } else {
192            Self::extract_function_name(handler_name)
193        };
194
195        let handler_fn = module.getattr(function_name.as_str()).map_err(|e| {
196            RuntimeError::HandlerNotFound(format!("Function '{}' not found: {}", function_name, e))
197        })?;
198
199        let inspect = py.import("inspect")?;
200        let sig = inspect.call_method1("signature", (handler_fn.as_any(),))?;
201        let params = sig.getattr("parameters")?;
202        let param_count = params.call_method0("__len__")?.extract::<usize>()?;
203
204        let state_module = py.import("generated.state")?;
205        let state_class = state_module.getattr("State")?;
206        
207        let log_fn_instance = Py::new(py, RohasLogFn {
208            handler_name: handler_name.to_string(),
209        })?;
210        
211        let log_fn_py: PyObject = log_fn_instance.into();
212        let state_obj = state_class.call1((handler_name, log_fn_py))?;
213        let state_obj_for_triggers = state_obj.clone();
214
215        let result = if param_count == 0 {
216            handler_fn
217                .call0()
218                .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
219        } else if is_event_handler {
220            let event_obj =
221                Self::instantiate_event_object(py, context, &handler_path).map_err(|e| {
222                    RuntimeError::ExecutionFailed(format!(
223                        "Failed to instantiate event object: {}",
224                        e
225                    ))
226                })?;
227
228            if param_count >= 2 {
229                handler_fn.call1((event_obj, state_obj)).map_err(|e| {
230                    RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e))
231                })?
232            } else {
233                handler_fn.call1((event_obj,)).map_err(|e| {
234                    RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e))
235                })?
236            }
237        } else if is_websocket_handler {
238            Self::call_websocket_handler(py, handler_fn, context, param_count, state_obj)
239                .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
240        } else if param_count >= 2 {
241            let request_dict = Self::build_request_dict(py, context)?;
242            let request_obj = Self::instantiate_request_class(py, handler_name, &request_dict)
243                .unwrap_or_else(|_| request_dict.clone().into_any());
244
245            handler_fn
246                .call1((request_obj, state_obj))
247                .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
248        } else {
249            let request_dict = Self::build_request_dict(py, context)?;
250            let request_obj = Self::instantiate_request_class(py, handler_name, &request_dict)
251                .unwrap_or_else(|_| request_dict.clone().into_any());
252
253            handler_fn
254                .call1((request_obj,))
255                .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
256        };
257
258        let final_result = if Self::is_coroutine(py, &result)? {
259            debug!("Handler is async, awaiting coroutine");
260            Self::await_coroutine(py, result)?
261        } else {
262            result
263        };
264
265        let json_str: String = if final_result.is_none() {
266            "null".to_string()
267        } else {
268            let json_module = py.import("json")?;
269
270            let json_ready = if let Ok(model_dump) = final_result.getattr("model_dump") {
271                match model_dump.call0() {
272                    Ok(dumped) => dumped,
273                    Err(_) => final_result,
274                }
275            } else if let Ok(dict_method) = final_result.getattr("dict") {
276                match dict_method.call0() {
277                    Ok(dumped) => dumped,
278                    Err(_) => final_result,
279                }
280            } else {
281                let dataclasses = py.import("dataclasses")?;
282                if dataclasses
283                    .call_method1("is_dataclass", (final_result.as_any(),))
284                    .and_then(|r| r.extract::<bool>())
285                    .unwrap_or(false)
286                {
287                    match dataclasses.call_method1("asdict", (final_result.as_any(),)) {
288                        Ok(dict) => dict,
289                        Err(_) => final_result,
290                    }
291                } else {
292                    final_result
293                }
294            };
295
296            match json_module.call_method1("dumps", (json_ready.as_any(),)) {
297                Ok(json_result) => json_result.extract::<String>()?,
298                Err(e) => {
299                    debug!("Failed to serialize response to JSON: {}, falling back to string representation", e);
300                    json_ready.str()?.to_string()
301                }
302            }
303        };
304
305        let data: serde_json::Value =
306            serde_json::from_str(&json_str).unwrap_or(serde_json::json!({"raw": json_str}));
307
308        let mut result = HandlerResult::success(data, 0);
309        if param_count >= 2 || (is_event_handler && param_count >= 2) {
310            if let Ok(triggers_py) = state_obj_for_triggers.call_method0("get_triggers") {
311                if let Ok(triggers_list) = triggers_py.downcast::<pyo3::types::PyList>() {
312                    for trigger_item in triggers_list.iter() {
313                        let event_name_py = trigger_item.getattr("event_name");
314                        let payload_py = trigger_item.getattr("payload");
315
316                        if let (Ok(event_name_py), Ok(payload_py)) = (event_name_py, payload_py) {
317                            if let Ok(event_name) = event_name_py.extract::<String>() {
318                                let json_module = py.import("json")?;
319                                if let Ok(payload_str) = json_module
320                                    .call_method1("dumps", (payload_py,))?
321                                    .extract::<String>()
322                                {
323                                    if let Ok(payload_value) =
324                                        serde_json::from_str::<serde_json::Value>(&payload_str)
325                                    {
326                                        debug!(
327                                            "Extracted manual trigger: {} with payload",
328                                            event_name
329                                        );
330                                        result = result.with_trigger(event_name, payload_value);
331                                    } else {
332                                        debug!(
333                                            "Failed to parse payload JSON for trigger: {}",
334                                            event_name
335                                        );
336                                    }
337                                } else {
338                                    debug!(
339                                        "Failed to serialize payload to JSON for trigger: {}",
340                                        event_name
341                                    );
342                                }
343                            } else {
344                                debug!("Failed to extract event_name from TriggeredEvent");
345                            }
346                        } else {
347                            debug!("Failed to get event_name or payload from TriggeredEvent");
348                        }
349                    }
350                } else {
351                    debug!("get_triggers() did not return a list");
352                }
353            } else {
354                debug!("Failed to call get_triggers() on State object");
355            }
356
357            if let Ok(payloads_py) =
358                state_obj_for_triggers.call_method0("get_all_auto_trigger_payloads")
359            {
360                if let Ok(payloads_dict) = payloads_py.downcast::<PyDict>() {
361                    for item in payloads_dict.iter() {
362                        let key = item.0;
363                        let value = item.1;
364                        if let Ok(event_name) = key.extract::<String>() {
365                            let json_module = py.import("json")?;
366                            if let Ok(payload_str) = json_module
367                                .call_method1("dumps", (value,))?
368                                .extract::<String>()
369                            {
370                                if let Ok(payload_value) =
371                                    serde_json::from_str::<serde_json::Value>(&payload_str)
372                                {
373                                    result =
374                                        result.with_auto_trigger_payload(event_name, payload_value);
375                                }
376                            }
377                        }
378                    }
379                }
380            }
381        }
382
383        Ok(result)
384    }
385
386    fn build_request_dict<'py>(
387        py: Python<'py>,
388        context: &HandlerContext,
389    ) -> PyResult<Bound<'py, PyDict>> {
390        let dict = PyDict::new(py);
391
392        let json_str = serde_json::to_string(&context.payload)
393            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
394        let json_module = py.import("json")?;
395        let payload_py = json_module.call_method1("loads", (json_str,))?;
396
397        if let Ok(payload_dict) = payload_py.downcast::<PyDict>() {
398            if !payload_dict.is_empty() {
399                if let Ok(body_value) = payload_dict.get_item("body") {
400                    if let Some(body) = body_value {
401                        dict.set_item("body", body)?;
402                    }
403                } else {
404                    for item in payload_dict.iter() {
405                        let key = item.0;
406                        let value = item.1;
407
408                        dict.set_item(key, value)?;
409                    }
410                }
411            }
412        }
413
414        let query_params_dict = PyDict::new(py);
415        for (key, value) in &context.query_params {
416            query_params_dict.set_item(key, value)?;
417        }
418        dict.set_item("query_params", query_params_dict)?;
419
420        Ok(dict)
421    }
422
423    fn instantiate_request_class<'py>(
424        py: Python<'py>,
425        handler_name: &str,
426        request_dict: &Bound<'py, PyDict>,
427    ) -> PyResult<Bound<'py, pyo3::PyAny>> {
428        let class_name = Self::handler_name_to_request_class(handler_name);
429
430        let module_name = handler_name.to_lowercase();
431
432        let import_path = format!("generated.api.{}", module_name);
433        let api_module = py.import(import_path.as_str())?;
434        let request_class = api_module.getattr(class_name.as_str())?;
435
436        request_class.call((), Some(request_dict))
437    }
438
439    fn handler_name_to_request_class(handler_name: &str) -> String {
440        let pascal_case = handler_name
441            .split('_')
442            .map(|word| {
443                let mut chars = word.chars();
444                match chars.next() {
445                    None => String::new(),
446                    Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
447                }
448            })
449            .collect::<String>();
450
451        format!("{}Request", pascal_case)
452    }
453
454    fn is_primitive_type(type_name: &str) -> bool {
455        matches!(
456            type_name,
457            "String" | "Int" | "Float" | "Boolean" | "Bool" | "DateTime" | "Date"
458        )
459    }
460
461    fn extract_primitive_value<'py>(
462        py: Python<'py>,
463        payload_dict: &Bound<'py, pyo3::PyAny>,
464        payload_type: &str,
465    ) -> PyResult<Bound<'py, pyo3::PyAny>> {
466        if let Ok(payload_dict_ref) = payload_dict.downcast::<PyDict>() {
467            if let Ok(Some(value)) = payload_dict_ref.get_item("payload") {
468                debug!(
469                    "Extracted primitive value from payload dict for type: {}",
470                    payload_type
471                );
472                return Ok(value);
473            }
474            let len = payload_dict_ref.len();
475            if len == 1 {
476                if let Some((_, value)) = payload_dict_ref.iter().next() {
477                    debug!(
478                        "Extracted primitive value from single-key dict for type: {}",
479                        payload_type
480                    );
481                    return Ok(value);
482                }
483            }
484        }
485        Ok(payload_dict.clone())
486    }
487
488    fn instantiate_event_object<'py>(
489        py: Python<'py>,
490        context: &HandlerContext,
491        _handler_path: &Path,
492    ) -> PyResult<Bound<'py, pyo3::PyAny>> {
493        let event_name = context.metadata.get("event_name").ok_or_else(|| {
494            pyo3::exceptions::PyValueError::new_err("Event name not found in context metadata")
495        })?;
496        let payload_type = context.metadata.get("event_payload_type");
497
498        let json_str = serde_json::to_string(&context.payload)
499            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
500        let json_module = py.import("json")?;
501        let payload_dict = json_module.call_method1("loads", (json_str,))?;
502        let payload_dict_clone = payload_dict.clone();
503
504        let datetime_module = py.import("datetime")?;
505        let now = datetime_module
506            .getattr("datetime")?
507            .getattr("now")?
508            .call0()?;
509        let now_clone = now.clone();
510
511        let convert_snake_to_camel = |dict: &Bound<'_, PyDict>| -> PyResult<Bound<'_, PyDict>> {
512            let camel_dict = PyDict::new(py);
513            for (key, value) in dict.iter() {
514                if let Ok(key_str) = key.extract::<String>() {
515                    let camel_key = if key_str.contains('_') {
516                        let parts: Vec<&str> = key_str.split('_').collect();
517                        let mut camel = String::new();
518                        for (i, part) in parts.iter().enumerate() {
519                            if i == 0 {
520                                camel.push_str(part);
521                            } else {
522                                let mut chars = part.chars();
523                                if let Some(first) = chars.next() {
524                                    camel.push(first.to_uppercase().next().unwrap());
525                                    camel.push_str(&chars.as_str());
526                                }
527                            }
528                        }
529                        camel
530                    } else {
531                        key_str
532                    };
533                    camel_dict.set_item(camel_key, value)?;
534                } else {
535                    camel_dict.set_item(key, value)?;
536                }
537            }
538            Ok(camel_dict)
539        };
540
541        let payload_obj = if let Some(payload_type_name) = payload_type {
542            if Self::is_primitive_type(payload_type_name) {
543                debug!(
544                    "Payload type {} is primitive, extracting value",
545                    payload_type_name
546                );
547                Self::extract_primitive_value(py, &payload_dict, payload_type_name)?
548            } else {
549                let payload_type_snake = templates::to_snake_case(payload_type_name);
550                let model_module_path = format!("generated.models.{}", payload_type_snake);
551
552                match py.import(&model_module_path) {
553                    Ok(model_module) => match model_module.getattr(payload_type_name.as_str()) {
554                        Ok(model_class) => {
555                            if let Ok(payload_dict_ref) = payload_dict.downcast::<PyDict>() {
556                                match convert_snake_to_camel(&payload_dict_ref) {
557                                    Ok(camel_payload_dict) => {
558                                        match model_class.call((), Some(&camel_payload_dict)) {
559                                            Ok(model_obj) => {
560                                                debug!(
561                                                    "Successfully instantiated payload model: {}",
562                                                    payload_type_name
563                                                );
564                                                model_obj.into_any()
565                                            }
566                                            Err(e) => {
567                                                debug!("Direct call failed for {}, trying model_validate: {}", payload_type_name, e);
568                                                if let Ok(model_validate) =
569                                                    model_class.getattr("model_validate")
570                                                {
571                                                    match model_validate
572                                                        .call1((camel_payload_dict,))
573                                                    {
574                                                        Ok(model_obj) => {
575                                                            debug!("Successfully instantiated payload model via model_validate: {}", payload_type_name);
576                                                            model_obj.into_any()
577                                                        }
578                                                        Err(e2) => {
579                                                            debug!("model_validate also failed for {}: {}", payload_type_name, e2);
580                                                            payload_dict.clone()
581                                                        }
582                                                    }
583                                                } else {
584                                                    payload_dict.clone()
585                                                }
586                                            }
587                                        }
588                                    }
589                                    Err(e) => {
590                                        debug!("Failed to convert field names: {}, trying with original dict", e);
591                                        match model_class.call((), Some(&payload_dict_ref)) {
592                                            Ok(model_obj) => {
593                                                debug!("Successfully instantiated payload model with original dict: {}", payload_type_name);
594                                                model_obj.into_any()
595                                            }
596                                            Err(_) => payload_dict.clone(),
597                                        }
598                                    }
599                                }
600                            } else {
601                                if let Ok(model_validate) = model_class.getattr("model_validate") {
602                                    match model_validate.call1((payload_dict.clone(),)) {
603                                        Ok(model_obj) => {
604                                            debug!("Successfully instantiated payload model via model_validate (PyAny): {}", payload_type_name);
605                                            model_obj.into_any()
606                                        }
607                                        Err(e) => {
608                                            debug!(
609                                                "model_validate failed for {} (PyAny): {}",
610                                                payload_type_name, e
611                                            );
612                                            payload_dict.clone()
613                                        }
614                                    }
615                                } else {
616                                    payload_dict.clone()
617                                }
618                            }
619                        }
620                        Err(e) => {
621                            debug!(
622                                "Failed to get model class {} from module: {}",
623                                payload_type_name, e
624                            );
625                            payload_dict.clone()
626                        }
627                    },
628                    Err(e) => {
629                        debug!("Failed to import model module {}: {}", model_module_path, e);
630                        payload_dict.clone()
631                    }
632                }
633            }
634        } else {
635            debug!("No payload type in metadata, using dict as-is");
636            payload_dict.clone()
637        };
638
639        let event_name_snake = templates::to_snake_case(event_name);
640        let event_module_path = format!("generated.events.{}", event_name_snake);
641
642        match py.import(&event_module_path) {
643            Ok(event_module) => match event_module.getattr(event_name.as_str()) {
644                Ok(event_class) => {
645                    let event_dict = PyDict::new(py);
646                    event_dict.set_item("payload", payload_obj)?;
647                    event_dict.set_item("timestamp", &now)?;
648
649                    let mut event_dict_for_direct = None;
650                    if let Ok(model_validate) = event_class.getattr("model_validate") {
651                        debug!("Attempting model_validate for event {}", event_name);
652                        match model_validate.call1((event_dict,)) {
653                            Ok(event_obj) => {
654                                debug!("model_validate call succeeded for event {}", event_name);
655                                match event_obj.getattr("payload") {
656                                    Ok(_) => {
657                                        debug!("Event object has payload attribute - instantiation successful via model_validate");
658                                        return Ok(event_obj);
659                                    }
660                                    Err(e) => {
661                                        debug!("Event object from model_validate missing payload attribute: {}", e);
662                                    }
663                                }
664                            }
665                            Err(e) => {
666                                let error_msg = format!("{}", e);
667                                debug!(
668                                    "model_validate failed for event {}: {}",
669                                    event_name, error_msg
670                                );
671                                let py_err = e.value(py);
672                                if let Ok(err_str) = py_err.str() {
673                                    debug!("Python error details: {}", err_str.to_string_lossy());
674                                }
675                                let json_str_direct = serde_json::to_string(&context.payload)
676                                    .map_err(|e| {
677                                        pyo3::exceptions::PyValueError::new_err(e.to_string())
678                                    })?;
679                                let payload_for_direct =
680                                    json_module.call_method1("loads", (json_str_direct,))?;
681                                let payload_for_direct_value =
682                                    if let Some(payload_type_name) = payload_type {
683                                        if Self::is_primitive_type(payload_type_name) {
684                                            Self::extract_primitive_value(
685                                                py,
686                                                &payload_for_direct,
687                                                payload_type_name,
688                                            )?
689                                        } else {
690                                            payload_for_direct
691                                        }
692                                    } else {
693                                        payload_for_direct
694                                    };
695                                let event_dict2 = PyDict::new(py);
696                                event_dict2.set_item("payload", payload_for_direct_value)?;
697                                event_dict2.set_item("timestamp", &now_clone)?;
698                                event_dict_for_direct = Some(event_dict2);
699                            }
700                        }
701                    } else {
702                        debug!(
703                            "model_validate method not found for event {}, using direct call",
704                            event_name
705                        );
706                        event_dict_for_direct = Some(event_dict);
707                    }
708
709                    if let Some(event_dict2) = event_dict_for_direct {
710                        debug!("Attempting direct call for event {}", event_name);
711                        match event_class.call((), Some(&event_dict2)) {
712                            Ok(event_obj) => {
713                                debug!("Direct call succeeded for event {}", event_name);
714                                match event_obj.getattr("payload") {
715                                    Ok(_) => {
716                                        debug!("Event object has payload attribute - instantiation successful via direct call");
717                                        return Ok(event_obj);
718                                    }
719                                    Err(e) => {
720                                        debug!("Event object missing payload attribute: {}", e);
721                                    }
722                                }
723                            }
724                            Err(e) => {
725                                let error_msg = format!("{}", e);
726                                debug!(
727                                    "Direct call also failed for event {}, error: {}",
728                                    event_name, error_msg
729                                );
730                                let py_err = e.value(py);
731                                if let Ok(err_str) = py_err.str() {
732                                    debug!("Python error details: {}", err_str.to_string_lossy());
733                                }
734                            }
735                        }
736                    }
737                }
738                Err(e) => {
739                    debug!(
740                        "Failed to get event class {} from module: {}",
741                        event_name, e
742                    );
743                }
744            },
745            Err(e) => {
746                debug!("Failed to import event module {}: {}", event_module_path, e);
747            }
748        }
749
750        debug!(
751            "Attempting final fallback instantiation for event: {} with dict payload",
752            event_name
753        );
754        let final_payload_value = if let Some(payload_type_name) = payload_type {
755            if Self::is_primitive_type(payload_type_name) {
756                Self::extract_primitive_value(py, &payload_dict_clone, payload_type_name)?
757            } else {
758                payload_dict_clone
759            }
760        } else {
761            payload_dict_clone
762        };
763        let final_event_dict = PyDict::new(py);
764        final_event_dict.set_item("payload", final_payload_value)?;
765        final_event_dict.set_item("timestamp", &now_clone)?;
766
767        if let Ok(event_module) = py.import(&event_module_path) {
768            if let Ok(event_class) = event_module.getattr(event_name.as_str()) {
769                let mut final_dict_for_direct = None;
770                if let Ok(model_validate) = event_class.getattr("model_validate") {
771                    match model_validate.call1((final_event_dict,)) {
772                        Ok(event_obj) => {
773                            if event_obj.getattr("payload").is_ok() {
774                                debug!("Successfully instantiated event via fallback model_validate with dict payload");
775                                return Ok(event_obj);
776                            }
777                        }
778                        Err(e) => {
779                            debug!("Fallback model_validate failed: {}", e);
780                            let json_str_fallback2 = serde_json::to_string(&context.payload)
781                                .map_err(|e| {
782                                    pyo3::exceptions::PyValueError::new_err(e.to_string())
783                                })?;
784                            let payload_dict_fallback2 =
785                                json_module.call_method1("loads", (json_str_fallback2,))?;
786                            let payload_fallback2_value =
787                                if let Some(payload_type_name) = payload_type {
788                                    if Self::is_primitive_type(payload_type_name) {
789                                        Self::extract_primitive_value(
790                                            py,
791                                            &payload_dict_fallback2,
792                                            payload_type_name,
793                                        )?
794                                    } else {
795                                        payload_dict_fallback2
796                                    }
797                                } else {
798                                    payload_dict_fallback2
799                                };
800                            let final_dict2 = PyDict::new(py);
801                            final_dict2.set_item("payload", payload_fallback2_value)?;
802                            final_dict2.set_item("timestamp", &now_clone)?;
803                            final_dict_for_direct = Some(final_dict2);
804                        }
805                    }
806                } else {
807                    final_dict_for_direct = Some(final_event_dict);
808                }
809                if let Some(final_dict2) = final_dict_for_direct {
810                    match event_class.call((), Some(&final_dict2)) {
811                        Ok(event_obj) => {
812                            if event_obj.getattr("payload").is_ok() {
813                                debug!("Successfully instantiated event via fallback direct call with dict payload");
814                                return Ok(event_obj);
815                            }
816                        }
817                        Err(e) => {
818                            debug!("Fallback direct call also failed: {}", e);
819                        }
820                    }
821                }
822            }
823        }
824
825        Err(pyo3::exceptions::PyValueError::new_err(format!(
826            "Failed to instantiate event object {}: All instantiation methods failed. Check debug logs for details.",
827            event_name
828        )))
829    }
830
831    fn context_to_py_dict<'py>(
832        py: Python<'py>,
833        context: &HandlerContext,
834    ) -> PyResult<Bound<'py, PyDict>> {
835        let dict = PyDict::new(py);
836
837        dict.set_item("handler_name", &context.handler_name)?;
838
839        let json_str = serde_json::to_string(&context.payload)
840            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
841        let json_module = py.import("json")?;
842        let payload_py = json_module.call_method1("loads", (json_str,))?;
843        dict.set_item("payload", payload_py)?;
844
845        let query_params_dict = PyDict::new(py);
846        for (key, value) in &context.query_params {
847            query_params_dict.set_item(key, value)?;
848        }
849        dict.set_item("query_params", query_params_dict)?;
850
851        dict.set_item("timestamp", &context.timestamp)?;
852
853        let metadata_dict = PyDict::new(py);
854        for (key, value) in &context.metadata {
855            metadata_dict.set_item(key, value)?;
856        }
857        dict.set_item("metadata", metadata_dict)?;
858
859        Ok(dict)
860    }
861
862    fn is_coroutine(py: Python<'_>, obj: &Bound<'_, pyo3::PyAny>) -> PyResult<bool> {
863        let inspect = py.import("inspect")?;
864        let is_coro = inspect.call_method1("iscoroutine", (obj,))?;
865        is_coro.extract::<bool>()
866    }
867
868    fn await_coroutine<'py>(
869        py: Python<'py>,
870        coro: Bound<'py, pyo3::PyAny>,
871    ) -> PyResult<Bound<'py, pyo3::PyAny>> {
872        let asyncio = py.import("asyncio")?;
873
874        let loop_result = asyncio.call_method0("get_event_loop");
875
876        let result = if let Ok(event_loop) = loop_result {
877            event_loop.call_method1("run_until_complete", (coro,))?
878        } else {
879            let new_loop = asyncio.call_method0("new_event_loop")?;
880            asyncio.call_method1("set_event_loop", (new_loop.as_any(),))?;
881            let result = new_loop.call_method1("run_until_complete", (coro,))?;
882            new_loop.call_method0("close")?;
883            result
884        };
885
886        Ok(result)
887    }
888
889    fn call_websocket_handler<'py>(
890        py: Python<'py>,
891        handler_fn: Bound<'py, pyo3::PyAny>,
892        context: &HandlerContext,
893        param_count: usize,
894        state_obj: Bound<'py, pyo3::PyAny>,
895    ) -> PyResult<Bound<'py, pyo3::PyAny>> {
896        let json_str = serde_json::to_string(&context.payload)
897            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
898        let json_module = py.import("json")?;
899        let payload_dict = json_module.call_method1("loads", (json_str,))?;
900        let payload_dict = payload_dict.downcast::<PyDict>().map_err(|e| {
901            pyo3::exceptions::PyValueError::new_err(format!("Payload is not a dict: {}", e))
902        })?;
903
904        let ws_name = context
905            .metadata
906            .get("websocket_name")
907            .map(|s| s.as_str())
908            .unwrap_or("HelloWorld");
909
910        let ws_name_snake = templates::to_snake_case(ws_name);
911        let ws_module_path = format!("generated.websockets.{}", ws_name_snake);
912
913        let (connection_class, message_class) = match py.import(&ws_module_path) {
914            Ok(ws_module) => {
915                let conn_class = ws_module.getattr(&format!("{}Connection", ws_name)).ok();
916                let msg_class = ws_module.getattr(&format!("{}Message", ws_name)).ok();
917                (conn_class, msg_class)
918            }
919            Err(_) => (None, None),
920        };
921
922        let connection_obj = if let Some(conn_class) = connection_class {
923            if let Ok(connection_dict) = payload_dict.get_item("connection") {
924                if let Some(conn_dict) = connection_dict {
925                    if let Ok(conn_dict) = conn_dict.downcast::<PyDict>() {
926                        if let Ok(model_validate) = conn_class.getattr("model_validate") {
927                            if let Ok(conn_obj) = model_validate.call1((conn_dict,)) {
928                                conn_obj
929                            } else {
930                                conn_class
931                                    .call((), Some(conn_dict))
932                                    .unwrap_or_else(|_| conn_dict.clone().into_any())
933                            }
934                        } else {
935                            conn_class
936                                .call((), Some(conn_dict))
937                                .unwrap_or_else(|_| conn_dict.clone().into_any())
938                        }
939                    } else {
940                        conn_dict.clone().into_any()
941                    }
942                } else {
943                    if let Ok(model_validate) = conn_class.getattr("model_validate") {
944                        if let Ok(conn_obj) = model_validate.call1((payload_dict,)) {
945                            conn_obj
946                        } else {
947                            conn_class
948                                .call((), Some(payload_dict))
949                                .unwrap_or_else(|_| payload_dict.clone().into_any())
950                        }
951                    } else {
952                        conn_class
953                            .call((), Some(payload_dict))
954                            .unwrap_or_else(|_| payload_dict.clone().into_any())
955                    }
956                }
957            } else {
958                if let Ok(model_validate) = conn_class.getattr("model_validate") {
959                    if let Ok(conn_obj) = model_validate.call1((payload_dict,)) {
960                        conn_obj
961                    } else {
962                        conn_class
963                            .call((), Some(payload_dict))
964                            .unwrap_or_else(|_| payload_dict.clone().into_any())
965                    }
966                } else {
967                    conn_class
968                        .call((), Some(payload_dict))
969                        .unwrap_or_else(|_| payload_dict.clone().into_any())
970                }
971            }
972        } else {
973            payload_dict.clone().into_any()
974        };
975
976        let message_obj = if param_count >= 3 {
977            if let Some(msg_class) = message_class {
978                if let Ok(message_dict) = payload_dict.get_item("message") {
979                    if let Some(msg_dict) = message_dict {
980                        if let Ok(msg_dict) = msg_dict.downcast::<PyDict>() {
981                            if let Ok(model_validate) = msg_class.getattr("model_validate") {
982                                if let Ok(msg_obj) = model_validate.call1((msg_dict,)) {
983                                    Some(msg_obj)
984                                } else {
985                                    Some(
986                                        msg_class
987                                            .call((), Some(msg_dict))
988                                            .unwrap_or_else(|_| msg_dict.clone().into_any()),
989                                    )
990                                }
991                            } else {
992                                Some(
993                                    msg_class
994                                        .call((), Some(msg_dict))
995                                        .unwrap_or_else(|_| msg_dict.clone().into_any()),
996                                )
997                            }
998                        } else {
999                            Some(msg_dict.clone().into_any())
1000                        }
1001                    } else {
1002                        None
1003                    }
1004                } else {
1005                    None
1006                }
1007            } else {
1008                None
1009            }
1010        } else {
1011            None
1012        };
1013
1014        if param_count == 3 {
1015            if let Some(msg) = message_obj {
1016                handler_fn.call1((msg, connection_obj, state_obj))
1017            } else {
1018                handler_fn.call1((payload_dict.clone().into_any(), connection_obj, state_obj))
1019            }
1020        } else if param_count == 2 {
1021            handler_fn.call1((connection_obj, state_obj))
1022        } else {
1023            handler_fn.call1((connection_obj,))
1024        }
1025    }
1026
1027    fn extract_function_name(handler_name: &str) -> String {
1028        if handler_name.chars().any(|c| c.is_uppercase()) {
1029            let snake = to_snake_case(handler_name);
1030            format!("handle_{}", snake)
1031        } else {
1032            format!("handle_{}", handler_name.to_string())
1033        }
1034    }
1035
1036    pub async fn reload_module(&self, module_name: &str) -> Result<()> {
1037        let mut modules = self.modules.write().await;
1038        modules.remove(module_name);
1039        info!("Reloaded Python module: {}", module_name);
1040        Ok(())
1041    }
1042}
1043
1044impl Default for PythonRuntime {
1045    fn default() -> Self {
1046        Self::new().expect("Failed to initialize Python runtime")
1047    }
1048}
1049
/// Converts a `CamelCase`/`PascalCase` identifier to `snake_case`.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by `_`. All other characters are copied unchanged.
/// Runs of capitals are split per character
/// (`"HTTPServer"` -> `"h_t_t_p_server"`).
fn to_snake_case(s: &str) -> String {
    // Lower bound only; underscores and multi-char lowercase mappings may grow it.
    let mut result = String::with_capacity(s.len());
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // Some characters lowercase to several chars (e.g. 'İ' -> "i\u{307}");
            // keep all of them instead of only the first.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
1064
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler names map to `handle_<snake_case>` Python function names.
    #[test]
    fn test_function_name_extraction() {
        let cases = [
            ("send_welcome_email", "handle_send_welcome_email"),
            ("CreateUser", "handle_create_user"),
        ];
        for (input, expected) in cases {
            assert_eq!(PythonRuntime::extract_function_name(input), expected);
        }
    }

    /// PascalCase identifiers are split on capitals and lowercased.
    #[test]
    fn test_to_snake_case() {
        let cases = [("CreateUser", "create_user"), ("UserCreated", "user_created")];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }
}