1use crate::error::{Result, RuntimeError};
2use crate::handler::{HandlerContext, HandlerResult};
3use pyo3::prelude::*;
4use pyo3::types::{PyDict, PyModule, PyTuple};
5use rohas_codegen::templates;
6use std::path::{Path, PathBuf};
7use std::sync::{Arc, Mutex};
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11#[pyclass]
12struct RohasLogFn {
13 handler_name: String,
14}
15
16#[pymethods]
17impl RohasLogFn {
18 fn __call__(
19 &self,
20 level: String,
21 handler: String,
22 message: String,
23 fields: Bound<'_, PyDict>,
24 ) -> PyResult<()> {
25 let mut field_map = std::collections::HashMap::new();
26 for (key, value) in fields.iter() {
27 let key_str = key.extract::<String>()?;
28 let value_str = format!("{:?}", value);
29 field_map.insert(key_str, value_str);
30 }
31
32 let span = tracing::span!(
33 tracing::Level::INFO,
34 "handler_log",
35 handler = %handler
36 );
37 let _enter = span.enter();
38
39 match level.as_str() {
40 "error" => tracing::error!(message = %message, ?field_map),
41 "warn" => tracing::warn!(message = %message, ?field_map),
42 "info" => tracing::info!(message = %message, ?field_map),
43 "debug" => tracing::debug!(message = %message, ?field_map),
44 "trace" => tracing::trace!(message = %message, ?field_map),
45 _ => tracing::info!(message = %message, ?field_map),
46 }
47
48 Ok(())
49 }
50}
51
52pub struct PythonRuntime {
53 modules: Arc<RwLock<std::collections::HashMap<String, Py<PyModule>>>>,
54 project_root: Arc<Mutex<Option<PathBuf>>>,
55}
56
57impl PythonRuntime {
58 pub fn new() -> Result<Self> {
59 Python::with_gil(|_| {
60 info!("Python runtime initialized");
61 });
62
63 Ok(Self {
64 modules: Arc::new(RwLock::new(std::collections::HashMap::new())),
65 project_root: Arc::new(Mutex::new(None)),
66 })
67 }
68
69 pub fn set_project_root(&mut self, root: PathBuf) {
70 let mut project_root = self.project_root.lock().unwrap();
71 *project_root = Some(root);
72 }
73
74 pub async fn execute_handler(
75 &self,
76 handler_path: &Path,
77 context: HandlerContext,
78 ) -> Result<HandlerResult> {
79 let start = std::time::Instant::now();
80 let handler_path = handler_path.to_path_buf();
81 let handler_name = context.handler_name.clone();
82 let project_root = self.project_root.lock().unwrap().clone();
83
84 debug!("Executing Python handler: {:?}", handler_path);
85
86 let task = tokio::task::spawn_blocking(move || {
87 Python::with_gil(|py| {
88 Self::execute_handler_sync(
89 py,
90 &handler_path,
91 &handler_name,
92 &context,
93 project_root.as_ref(),
94 )
95 })
96 });
97
98 let result = tokio::time::timeout(std::time::Duration::from_secs(30), task)
99 .await
100 .map_err(|_| RuntimeError::ExecutionFailed("Handler execution timeout (30s)".into()))?
101 .map_err(|e| RuntimeError::ExecutionFailed(format!("Task join error: {}", e)))??;
102
103 let execution_time_ms = start.elapsed().as_millis() as u64;
104 Ok(HandlerResult {
105 execution_time_ms,
106 ..result
107 })
108 }
109
110 fn execute_handler_sync(
111 py: Python<'_>,
112 handler_path: &Path,
113 handler_name: &str,
114 context: &HandlerContext,
115 project_root: Option<&PathBuf>,
116 ) -> Result<HandlerResult> {
117 let sys = py.import("sys")?;
118 let sys_path = sys.getattr("path")?;
119
120 if let Some(parent) = handler_path.parent() {
121 sys_path.call_method1("insert", (0, parent.to_str().unwrap()))?;
122 }
123
124 if let Some(root) = project_root {
125 let src_path = root.join("src");
126 if src_path.exists() {
127 let src_path_str = src_path.to_str().unwrap();
128 let path_list: Vec<String> = sys_path.extract()?;
129 if !path_list.contains(&src_path_str.to_string()) {
130 sys_path.call_method1("append", (src_path_str,))?;
131 debug!("Added to sys.path (appended): {:?}", src_path);
132 }
133 }
134 }
135
136 let module_name = handler_path
137 .file_stem()
138 .and_then(|s| s.to_str())
139 .ok_or_else(|| RuntimeError::ExecutionFailed("Invalid module name".into()))?;
140
141 let importlib = py.import("importlib")?;
144 let _ = importlib.call_method0("invalidate_caches");
145
146 if let Ok(modules_dict) = sys.getattr("modules") {
147 let _ = modules_dict.del_item(module_name);
148 }
149
150 let module = PyModule::import(py, module_name).map_err(|e| {
151 RuntimeError::ExecutionFailed(format!("Failed to import module: {}", e))
152 })?;
153
154 let is_event_handler = handler_path
155 .parent()
156 .and_then(|p| p.file_name())
157 .and_then(|n| n.to_str())
158 .map(|n| n == "events")
159 .unwrap_or(false);
160
161 let is_websocket_handler = handler_path
162 .parent()
163 .and_then(|p| p.file_name())
164 .and_then(|n| n.to_str())
165 .map(|n| n == "websockets")
166 .unwrap_or(false);
167
168 let is_middleware = handler_path
169 .parent()
170 .and_then(|p| p.file_name())
171 .and_then(|n| n.to_str())
172 .map(|n| n == "middlewares")
173 .unwrap_or(false);
174
175 let function_name = if is_event_handler || is_websocket_handler {
176 let direct_name = handler_name.to_string();
177 let handle_name = format!("handle_{}", handler_name);
178
179 match module.hasattr(handle_name.as_str()) {
180 Ok(true) => {
181 debug!("Using function name '{}' for handler '{}'", handle_name, handler_name);
182 handle_name
183 }
184 _ => {
185 debug!("Using function name '{}' for handler '{}' (handle_ variant not found)", direct_name, handler_name);
186 direct_name
187 }
188 }
189 } else if is_middleware {
190 format!("{}_middleware", templates::to_snake_case(handler_name))
191 } else {
192 Self::extract_function_name(handler_name)
193 };
194
195 let handler_fn = module.getattr(function_name.as_str()).map_err(|e| {
196 RuntimeError::HandlerNotFound(format!("Function '{}' not found: {}", function_name, e))
197 })?;
198
199 let inspect = py.import("inspect")?;
200 let sig = inspect.call_method1("signature", (handler_fn.as_any(),))?;
201 let params = sig.getattr("parameters")?;
202 let param_count = params.call_method0("__len__")?.extract::<usize>()?;
203
204 let state_module = py.import("generated.state")?;
205 let state_class = state_module.getattr("State")?;
206
207 let log_fn_instance = Py::new(py, RohasLogFn {
208 handler_name: handler_name.to_string(),
209 })?;
210
211 let log_fn_py: PyObject = log_fn_instance.into();
212 let state_obj = state_class.call1((handler_name, log_fn_py))?;
213 let state_obj_for_triggers = state_obj.clone();
214
215 let result = if param_count == 0 {
216 handler_fn
217 .call0()
218 .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
219 } else if is_event_handler {
220 let event_obj =
221 Self::instantiate_event_object(py, context, &handler_path).map_err(|e| {
222 RuntimeError::ExecutionFailed(format!(
223 "Failed to instantiate event object: {}",
224 e
225 ))
226 })?;
227
228 if param_count >= 2 {
229 handler_fn.call1((event_obj, state_obj)).map_err(|e| {
230 RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e))
231 })?
232 } else {
233 handler_fn.call1((event_obj,)).map_err(|e| {
234 RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e))
235 })?
236 }
237 } else if is_websocket_handler {
238 Self::call_websocket_handler(py, handler_fn, context, param_count, state_obj)
239 .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
240 } else if param_count >= 2 {
241 let request_dict = Self::build_request_dict(py, context)?;
242 let request_obj = Self::instantiate_request_class(py, handler_name, &request_dict)
243 .unwrap_or_else(|_| request_dict.clone().into_any());
244
245 handler_fn
246 .call1((request_obj, state_obj))
247 .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
248 } else {
249 let request_dict = Self::build_request_dict(py, context)?;
250 let request_obj = Self::instantiate_request_class(py, handler_name, &request_dict)
251 .unwrap_or_else(|_| request_dict.clone().into_any());
252
253 handler_fn
254 .call1((request_obj,))
255 .map_err(|e| RuntimeError::ExecutionFailed(format!("Handler call failed: {}", e)))?
256 };
257
258 let final_result = if Self::is_coroutine(py, &result)? {
259 debug!("Handler is async, awaiting coroutine");
260 Self::await_coroutine(py, result)?
261 } else {
262 result
263 };
264
265 let json_str: String = if final_result.is_none() {
266 "null".to_string()
267 } else {
268 let json_module = py.import("json")?;
269
270 let json_ready = if let Ok(model_dump) = final_result.getattr("model_dump") {
271 match model_dump.call0() {
272 Ok(dumped) => dumped,
273 Err(_) => final_result,
274 }
275 } else if let Ok(dict_method) = final_result.getattr("dict") {
276 match dict_method.call0() {
277 Ok(dumped) => dumped,
278 Err(_) => final_result,
279 }
280 } else {
281 let dataclasses = py.import("dataclasses")?;
282 if dataclasses
283 .call_method1("is_dataclass", (final_result.as_any(),))
284 .and_then(|r| r.extract::<bool>())
285 .unwrap_or(false)
286 {
287 match dataclasses.call_method1("asdict", (final_result.as_any(),)) {
288 Ok(dict) => dict,
289 Err(_) => final_result,
290 }
291 } else {
292 final_result
293 }
294 };
295
296 match json_module.call_method1("dumps", (json_ready.as_any(),)) {
297 Ok(json_result) => json_result.extract::<String>()?,
298 Err(e) => {
299 debug!("Failed to serialize response to JSON: {}, falling back to string representation", e);
300 json_ready.str()?.to_string()
301 }
302 }
303 };
304
305 let data: serde_json::Value =
306 serde_json::from_str(&json_str).unwrap_or(serde_json::json!({"raw": json_str}));
307
308 let mut result = HandlerResult::success(data, 0);
309 if param_count >= 2 || (is_event_handler && param_count >= 2) {
310 if let Ok(triggers_py) = state_obj_for_triggers.call_method0("get_triggers") {
311 if let Ok(triggers_list) = triggers_py.downcast::<pyo3::types::PyList>() {
312 for trigger_item in triggers_list.iter() {
313 let event_name_py = trigger_item.getattr("event_name");
314 let payload_py = trigger_item.getattr("payload");
315
316 if let (Ok(event_name_py), Ok(payload_py)) = (event_name_py, payload_py) {
317 if let Ok(event_name) = event_name_py.extract::<String>() {
318 let json_module = py.import("json")?;
319 if let Ok(payload_str) = json_module
320 .call_method1("dumps", (payload_py,))?
321 .extract::<String>()
322 {
323 if let Ok(payload_value) =
324 serde_json::from_str::<serde_json::Value>(&payload_str)
325 {
326 debug!(
327 "Extracted manual trigger: {} with payload",
328 event_name
329 );
330 result = result.with_trigger(event_name, payload_value);
331 } else {
332 debug!(
333 "Failed to parse payload JSON for trigger: {}",
334 event_name
335 );
336 }
337 } else {
338 debug!(
339 "Failed to serialize payload to JSON for trigger: {}",
340 event_name
341 );
342 }
343 } else {
344 debug!("Failed to extract event_name from TriggeredEvent");
345 }
346 } else {
347 debug!("Failed to get event_name or payload from TriggeredEvent");
348 }
349 }
350 } else {
351 debug!("get_triggers() did not return a list");
352 }
353 } else {
354 debug!("Failed to call get_triggers() on State object");
355 }
356
357 if let Ok(payloads_py) =
358 state_obj_for_triggers.call_method0("get_all_auto_trigger_payloads")
359 {
360 if let Ok(payloads_dict) = payloads_py.downcast::<PyDict>() {
361 for item in payloads_dict.iter() {
362 let key = item.0;
363 let value = item.1;
364 if let Ok(event_name) = key.extract::<String>() {
365 let json_module = py.import("json")?;
366 if let Ok(payload_str) = json_module
367 .call_method1("dumps", (value,))?
368 .extract::<String>()
369 {
370 if let Ok(payload_value) =
371 serde_json::from_str::<serde_json::Value>(&payload_str)
372 {
373 result =
374 result.with_auto_trigger_payload(event_name, payload_value);
375 }
376 }
377 }
378 }
379 }
380 }
381 }
382
383 Ok(result)
384 }
385
386 fn build_request_dict<'py>(
387 py: Python<'py>,
388 context: &HandlerContext,
389 ) -> PyResult<Bound<'py, PyDict>> {
390 let dict = PyDict::new(py);
391
392 let json_str = serde_json::to_string(&context.payload)
393 .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
394 let json_module = py.import("json")?;
395 let payload_py = json_module.call_method1("loads", (json_str,))?;
396
397 if let Ok(payload_dict) = payload_py.downcast::<PyDict>() {
398 if !payload_dict.is_empty() {
399 if let Ok(body_value) = payload_dict.get_item("body") {
400 if let Some(body) = body_value {
401 dict.set_item("body", body)?;
402 }
403 } else {
404 for item in payload_dict.iter() {
405 let key = item.0;
406 let value = item.1;
407
408 dict.set_item(key, value)?;
409 }
410 }
411 }
412 }
413
414 let query_params_dict = PyDict::new(py);
415 for (key, value) in &context.query_params {
416 query_params_dict.set_item(key, value)?;
417 }
418 dict.set_item("query_params", query_params_dict)?;
419
420 Ok(dict)
421 }
422
423 fn instantiate_request_class<'py>(
424 py: Python<'py>,
425 handler_name: &str,
426 request_dict: &Bound<'py, PyDict>,
427 ) -> PyResult<Bound<'py, pyo3::PyAny>> {
428 let class_name = Self::handler_name_to_request_class(handler_name);
429
430 let module_name = handler_name.to_lowercase();
431
432 let import_path = format!("generated.api.{}", module_name);
433 let api_module = py.import(import_path.as_str())?;
434 let request_class = api_module.getattr(class_name.as_str())?;
435
436 request_class.call((), Some(request_dict))
437 }
438
439 fn handler_name_to_request_class(handler_name: &str) -> String {
440 let pascal_case = handler_name
441 .split('_')
442 .map(|word| {
443 let mut chars = word.chars();
444 match chars.next() {
445 None => String::new(),
446 Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
447 }
448 })
449 .collect::<String>();
450
451 format!("{}Request", pascal_case)
452 }
453
454 fn is_primitive_type(type_name: &str) -> bool {
455 matches!(
456 type_name,
457 "String" | "Int" | "Float" | "Boolean" | "Bool" | "DateTime" | "Date"
458 )
459 }
460
461 fn extract_primitive_value<'py>(
462 py: Python<'py>,
463 payload_dict: &Bound<'py, pyo3::PyAny>,
464 payload_type: &str,
465 ) -> PyResult<Bound<'py, pyo3::PyAny>> {
466 if let Ok(payload_dict_ref) = payload_dict.downcast::<PyDict>() {
467 if let Ok(Some(value)) = payload_dict_ref.get_item("payload") {
468 debug!(
469 "Extracted primitive value from payload dict for type: {}",
470 payload_type
471 );
472 return Ok(value);
473 }
474 let len = payload_dict_ref.len();
475 if len == 1 {
476 if let Some((_, value)) = payload_dict_ref.iter().next() {
477 debug!(
478 "Extracted primitive value from single-key dict for type: {}",
479 payload_type
480 );
481 return Ok(value);
482 }
483 }
484 }
485 Ok(payload_dict.clone())
486 }
487
    /// Builds the Python event object passed to an event handler.
    ///
    /// Reads `event_name` (required) and `event_payload_type` (optional) from
    /// `context.metadata`, decodes `context.payload` into Python objects via `json.loads`,
    /// and timestamps the event with Python's `datetime.datetime.now()`.
    ///
    /// Payload construction:
    /// - primitive payload types (per `is_primitive_type`) are unwrapped with
    ///   `extract_primitive_value`;
    /// - other types are looked up as `generated.models.<snake_type>.<Type>`. For a dict
    ///   payload the model is called with camelCased keyword arguments, then via
    ///   `model_validate`; if key conversion fails, it is called with the original keys.
    ///   A non-dict payload only tries `model_validate`. Any failure falls back to the plain
    ///   decoded payload.
    ///
    /// Event construction imports `generated.events.<snake_name>` and looks up
    /// `<event_name>`. It first tries `model_validate({"payload", "timestamp"})` with the
    /// built payload. If that raises, it tries a keyword-argument call with a freshly decoded
    /// payload. If the class has no `model_validate`, it calls the class directly with the
    /// built payload. If nothing succeeded, a final fallback repeats the model_validate /
    /// keyword-call sequence with the plain decoded payload. A candidate is accepted only if
    /// it has a `payload` attribute.
    ///
    /// `_handler_path` is unused.
    ///
    /// # Errors
    /// `ValueError` if `event_name` is missing from metadata, the payload cannot be
    /// serialized, or every construction attempt fails. Python errors from importing
    /// `json`/`datetime`, decoding JSON or setting dict items are propagated.
    fn instantiate_event_object<'py>(
        py: Python<'py>,
        context: &HandlerContext,
        _handler_path: &Path,
    ) -> PyResult<Bound<'py, pyo3::PyAny>> {
        // Required: the event class name; also determines the generated module path below.
        let event_name = context.metadata.get("event_name").ok_or_else(|| {
            pyo3::exceptions::PyValueError::new_err("Event name not found in context metadata")
        })?;
        // Optional: decides between primitive unwrapping and generated-model lookup.
        let payload_type = context.metadata.get("event_payload_type");

        // Decode the JSON payload into Python objects (despite the name, not necessarily a dict).
        let json_str = serde_json::to_string(&context.payload)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
        let json_module = py.import("json")?;
        let payload_dict = json_module.call_method1("loads", (json_str,))?;
        // Another handle to the same Python object (Bound::clone does not copy it);
        // used by the final fallback below.
        let payload_dict_clone = payload_dict.clone();

        // One timestamp shared by every construction attempt (naive local time).
        let datetime_module = py.import("datetime")?;
        let now = datetime_module
            .getattr("datetime")?
            .getattr("now")?
            .call0()?;
        let now_clone = now.clone();

        // Returns a new dict whose str keys are converted from snake_case to camelCase:
        // the first segment is kept, later segments get their first char uppercased, and
        // empty segments (from "__") are dropped. Non-str keys are copied unchanged.
        let convert_snake_to_camel = |dict: &Bound<'_, PyDict>| -> PyResult<Bound<'_, PyDict>> {
            let camel_dict = PyDict::new(py);
            for (key, value) in dict.iter() {
                if let Ok(key_str) = key.extract::<String>() {
                    let camel_key = if key_str.contains('_') {
                        let parts: Vec<&str> = key_str.split('_').collect();
                        let mut camel = String::new();
                        for (i, part) in parts.iter().enumerate() {
                            if i == 0 {
                                camel.push_str(part);
                            } else {
                                let mut chars = part.chars();
                                if let Some(first) = chars.next() {
                                    camel.push(first.to_uppercase().next().unwrap());
                                    camel.push_str(&chars.as_str());
                                }
                            }
                        }
                        camel
                    } else {
                        key_str
                    };
                    camel_dict.set_item(camel_key, value)?;
                } else {
                    camel_dict.set_item(key, value)?;
                }
            }
            Ok(camel_dict)
        };

        // Build the payload value: primitive unwrap, generated model instance, or (on any
        // failure / missing type) the decoded payload as-is.
        let payload_obj = if let Some(payload_type_name) = payload_type {
            if Self::is_primitive_type(payload_type_name) {
                debug!(
                    "Payload type {} is primitive, extracting value",
                    payload_type_name
                );
                Self::extract_primitive_value(py, &payload_dict, payload_type_name)?
            } else {
                let payload_type_snake = templates::to_snake_case(payload_type_name);
                let model_module_path = format!("generated.models.{}", payload_type_snake);

                match py.import(&model_module_path) {
                    Ok(model_module) => match model_module.getattr(payload_type_name.as_str()) {
                        Ok(model_class) => {
                            if let Ok(payload_dict_ref) = payload_dict.downcast::<PyDict>() {
                                // Dict payload: try `Model(**camelCased)`, then
                                // `Model.model_validate(camelCased)`.
                                match convert_snake_to_camel(&payload_dict_ref) {
                                    Ok(camel_payload_dict) => {
                                        match model_class.call((), Some(&camel_payload_dict)) {
                                            Ok(model_obj) => {
                                                debug!(
                                                    "Successfully instantiated payload model: {}",
                                                    payload_type_name
                                                );
                                                model_obj.into_any()
                                            }
                                            Err(e) => {
                                                debug!("Direct call failed for {}, trying model_validate: {}", payload_type_name, e);
                                                if let Ok(model_validate) =
                                                    model_class.getattr("model_validate")
                                                {
                                                    match model_validate
                                                        .call1((camel_payload_dict,))
                                                    {
                                                        Ok(model_obj) => {
                                                            debug!("Successfully instantiated payload model via model_validate: {}", payload_type_name);
                                                            model_obj.into_any()
                                                        }
                                                        Err(e2) => {
                                                            debug!("model_validate also failed for {}: {}", payload_type_name, e2);
                                                            payload_dict.clone()
                                                        }
                                                    }
                                                } else {
                                                    payload_dict.clone()
                                                }
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        // Key conversion failed: try the original keys once.
                                        debug!("Failed to convert field names: {}, trying with original dict", e);
                                        match model_class.call((), Some(&payload_dict_ref)) {
                                            Ok(model_obj) => {
                                                debug!("Successfully instantiated payload model with original dict: {}", payload_type_name);
                                                model_obj.into_any()
                                            }
                                            Err(_) => payload_dict.clone(),
                                        }
                                    }
                                }
                            } else {
                                // Non-dict payload: only `model_validate` is attempted.
                                if let Ok(model_validate) = model_class.getattr("model_validate") {
                                    match model_validate.call1((payload_dict.clone(),)) {
                                        Ok(model_obj) => {
                                            debug!("Successfully instantiated payload model via model_validate (PyAny): {}", payload_type_name);
                                            model_obj.into_any()
                                        }
                                        Err(e) => {
                                            debug!(
                                                "model_validate failed for {} (PyAny): {}",
                                                payload_type_name, e
                                            );
                                            payload_dict.clone()
                                        }
                                    }
                                } else {
                                    payload_dict.clone()
                                }
                            }
                        }
                        Err(e) => {
                            debug!(
                                "Failed to get model class {} from module: {}",
                                payload_type_name, e
                            );
                            payload_dict.clone()
                        }
                    },
                    Err(e) => {
                        debug!("Failed to import model module {}: {}", model_module_path, e);
                        payload_dict.clone()
                    }
                }
            }
        } else {
            debug!("No payload type in metadata, using dict as-is");
            payload_dict.clone()
        };

        let event_name_snake = templates::to_snake_case(event_name);
        let event_module_path = format!("generated.events.{}", event_name_snake);

        // First attempt: event class with the built payload.
        match py.import(&event_module_path) {
            Ok(event_module) => match event_module.getattr(event_name.as_str()) {
                Ok(event_class) => {
                    let event_dict = PyDict::new(py);
                    event_dict.set_item("payload", payload_obj)?;
                    event_dict.set_item("timestamp", &now)?;

                    // Set only when a keyword-argument call should be attempted next. It stays
                    // None if model_validate succeeded but lacked `payload`.
                    let mut event_dict_for_direct = None;
                    if let Ok(model_validate) = event_class.getattr("model_validate") {
                        debug!("Attempting model_validate for event {}", event_name);
                        match model_validate.call1((event_dict,)) {
                            Ok(event_obj) => {
                                debug!("model_validate call succeeded for event {}", event_name);
                                match event_obj.getattr("payload") {
                                    Ok(_) => {
                                        debug!("Event object has payload attribute - instantiation successful via model_validate");
                                        return Ok(event_obj);
                                    }
                                    Err(e) => {
                                        debug!("Event object from model_validate missing payload attribute: {}", e);
                                    }
                                }
                            }
                            Err(e) => {
                                let error_msg = format!("{}", e);
                                debug!(
                                    "model_validate failed for event {}: {}",
                                    event_name, error_msg
                                );
                                let py_err = e.value(py);
                                if let Ok(err_str) = py_err.str() {
                                    debug!("Python error details: {}", err_str.to_string_lossy());
                                }
                                // Retry as a keyword call with a freshly decoded payload
                                // (primitive-unwrapped if applicable), not the built model.
                                let json_str_direct = serde_json::to_string(&context.payload)
                                    .map_err(|e| {
                                        pyo3::exceptions::PyValueError::new_err(e.to_string())
                                    })?;
                                let payload_for_direct =
                                    json_module.call_method1("loads", (json_str_direct,))?;
                                let payload_for_direct_value =
                                    if let Some(payload_type_name) = payload_type {
                                        if Self::is_primitive_type(payload_type_name) {
                                            Self::extract_primitive_value(
                                                py,
                                                &payload_for_direct,
                                                payload_type_name,
                                            )?
                                        } else {
                                            payload_for_direct
                                        }
                                    } else {
                                        payload_for_direct
                                    };
                                let event_dict2 = PyDict::new(py);
                                event_dict2.set_item("payload", payload_for_direct_value)?;
                                event_dict2.set_item("timestamp", &now_clone)?;
                                event_dict_for_direct = Some(event_dict2);
                            }
                        }
                    } else {
                        debug!(
                            "model_validate method not found for event {}, using direct call",
                            event_name
                        );
                        event_dict_for_direct = Some(event_dict);
                    }

                    if let Some(event_dict2) = event_dict_for_direct {
                        debug!("Attempting direct call for event {}", event_name);
                        match event_class.call((), Some(&event_dict2)) {
                            Ok(event_obj) => {
                                debug!("Direct call succeeded for event {}", event_name);
                                match event_obj.getattr("payload") {
                                    Ok(_) => {
                                        debug!("Event object has payload attribute - instantiation successful via direct call");
                                        return Ok(event_obj);
                                    }
                                    Err(e) => {
                                        debug!("Event object missing payload attribute: {}", e);
                                    }
                                }
                            }
                            Err(e) => {
                                let error_msg = format!("{}", e);
                                debug!(
                                    "Direct call also failed for event {}, error: {}",
                                    event_name, error_msg
                                );
                                let py_err = e.value(py);
                                if let Ok(err_str) = py_err.str() {
                                    debug!("Python error details: {}", err_str.to_string_lossy());
                                }
                            }
                        }
                    }
                }
                Err(e) => {
                    debug!(
                        "Failed to get event class {} from module: {}",
                        event_name, e
                    );
                }
            },
            Err(e) => {
                debug!("Failed to import event module {}: {}", event_module_path, e);
            }
        }

        // Final fallback: repeat the attempts with the plain decoded payload
        // (primitive-unwrapped if applicable) instead of the built model.
        debug!(
            "Attempting final fallback instantiation for event: {} with dict payload",
            event_name
        );
        let final_payload_value = if let Some(payload_type_name) = payload_type {
            if Self::is_primitive_type(payload_type_name) {
                Self::extract_primitive_value(py, &payload_dict_clone, payload_type_name)?
            } else {
                payload_dict_clone
            }
        } else {
            payload_dict_clone
        };
        let final_event_dict = PyDict::new(py);
        final_event_dict.set_item("payload", final_payload_value)?;
        final_event_dict.set_item("timestamp", &now_clone)?;

        if let Ok(event_module) = py.import(&event_module_path) {
            if let Ok(event_class) = event_module.getattr(event_name.as_str()) {
                let mut final_dict_for_direct = None;
                if let Ok(model_validate) = event_class.getattr("model_validate") {
                    match model_validate.call1((final_event_dict,)) {
                        Ok(event_obj) => {
                            if event_obj.getattr("payload").is_ok() {
                                debug!("Successfully instantiated event via fallback model_validate with dict payload");
                                return Ok(event_obj);
                            }
                        }
                        Err(e) => {
                            debug!("Fallback model_validate failed: {}", e);
                            // `final_event_dict` was moved into the failed call; rebuild it.
                            let json_str_fallback2 = serde_json::to_string(&context.payload)
                                .map_err(|e| {
                                    pyo3::exceptions::PyValueError::new_err(e.to_string())
                                })?;
                            let payload_dict_fallback2 =
                                json_module.call_method1("loads", (json_str_fallback2,))?;
                            let payload_fallback2_value =
                                if let Some(payload_type_name) = payload_type {
                                    if Self::is_primitive_type(payload_type_name) {
                                        Self::extract_primitive_value(
                                            py,
                                            &payload_dict_fallback2,
                                            payload_type_name,
                                        )?
                                    } else {
                                        payload_dict_fallback2
                                    }
                                } else {
                                    payload_dict_fallback2
                                };
                            let final_dict2 = PyDict::new(py);
                            final_dict2.set_item("payload", payload_fallback2_value)?;
                            final_dict2.set_item("timestamp", &now_clone)?;
                            final_dict_for_direct = Some(final_dict2);
                        }
                    }
                } else {
                    final_dict_for_direct = Some(final_event_dict);
                }
                if let Some(final_dict2) = final_dict_for_direct {
                    match event_class.call((), Some(&final_dict2)) {
                        Ok(event_obj) => {
                            if event_obj.getattr("payload").is_ok() {
                                debug!("Successfully instantiated event via fallback direct call with dict payload");
                                return Ok(event_obj);
                            }
                        }
                        Err(e) => {
                            debug!("Fallback direct call also failed: {}", e);
                        }
                    }
                }
            }
        }

        // Every attempt failed; the individual reasons were only logged at debug level.
        Err(pyo3::exceptions::PyValueError::new_err(format!(
            "Failed to instantiate event object {}: All instantiation methods failed. Check debug logs for details.",
            event_name
        )))
    }
830
831 fn context_to_py_dict<'py>(
832 py: Python<'py>,
833 context: &HandlerContext,
834 ) -> PyResult<Bound<'py, PyDict>> {
835 let dict = PyDict::new(py);
836
837 dict.set_item("handler_name", &context.handler_name)?;
838
839 let json_str = serde_json::to_string(&context.payload)
840 .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
841 let json_module = py.import("json")?;
842 let payload_py = json_module.call_method1("loads", (json_str,))?;
843 dict.set_item("payload", payload_py)?;
844
845 let query_params_dict = PyDict::new(py);
846 for (key, value) in &context.query_params {
847 query_params_dict.set_item(key, value)?;
848 }
849 dict.set_item("query_params", query_params_dict)?;
850
851 dict.set_item("timestamp", &context.timestamp)?;
852
853 let metadata_dict = PyDict::new(py);
854 for (key, value) in &context.metadata {
855 metadata_dict.set_item(key, value)?;
856 }
857 dict.set_item("metadata", metadata_dict)?;
858
859 Ok(dict)
860 }
861
862 fn is_coroutine(py: Python<'_>, obj: &Bound<'_, pyo3::PyAny>) -> PyResult<bool> {
863 let inspect = py.import("inspect")?;
864 let is_coro = inspect.call_method1("iscoroutine", (obj,))?;
865 is_coro.extract::<bool>()
866 }
867
868 fn await_coroutine<'py>(
869 py: Python<'py>,
870 coro: Bound<'py, pyo3::PyAny>,
871 ) -> PyResult<Bound<'py, pyo3::PyAny>> {
872 let asyncio = py.import("asyncio")?;
873
874 let loop_result = asyncio.call_method0("get_event_loop");
875
876 let result = if let Ok(event_loop) = loop_result {
877 event_loop.call_method1("run_until_complete", (coro,))?
878 } else {
879 let new_loop = asyncio.call_method0("new_event_loop")?;
880 asyncio.call_method1("set_event_loop", (new_loop.as_any(),))?;
881 let result = new_loop.call_method1("run_until_complete", (coro,))?;
882 new_loop.call_method0("close")?;
883 result
884 };
885
886 Ok(result)
887 }
888
889 fn call_websocket_handler<'py>(
890 py: Python<'py>,
891 handler_fn: Bound<'py, pyo3::PyAny>,
892 context: &HandlerContext,
893 param_count: usize,
894 state_obj: Bound<'py, pyo3::PyAny>,
895 ) -> PyResult<Bound<'py, pyo3::PyAny>> {
896 let json_str = serde_json::to_string(&context.payload)
897 .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
898 let json_module = py.import("json")?;
899 let payload_dict = json_module.call_method1("loads", (json_str,))?;
900 let payload_dict = payload_dict.downcast::<PyDict>().map_err(|e| {
901 pyo3::exceptions::PyValueError::new_err(format!("Payload is not a dict: {}", e))
902 })?;
903
904 let ws_name = context
905 .metadata
906 .get("websocket_name")
907 .map(|s| s.as_str())
908 .unwrap_or("HelloWorld");
909
910 let ws_name_snake = templates::to_snake_case(ws_name);
911 let ws_module_path = format!("generated.websockets.{}", ws_name_snake);
912
913 let (connection_class, message_class) = match py.import(&ws_module_path) {
914 Ok(ws_module) => {
915 let conn_class = ws_module.getattr(&format!("{}Connection", ws_name)).ok();
916 let msg_class = ws_module.getattr(&format!("{}Message", ws_name)).ok();
917 (conn_class, msg_class)
918 }
919 Err(_) => (None, None),
920 };
921
922 let connection_obj = if let Some(conn_class) = connection_class {
923 if let Ok(connection_dict) = payload_dict.get_item("connection") {
924 if let Some(conn_dict) = connection_dict {
925 if let Ok(conn_dict) = conn_dict.downcast::<PyDict>() {
926 if let Ok(model_validate) = conn_class.getattr("model_validate") {
927 if let Ok(conn_obj) = model_validate.call1((conn_dict,)) {
928 conn_obj
929 } else {
930 conn_class
931 .call((), Some(conn_dict))
932 .unwrap_or_else(|_| conn_dict.clone().into_any())
933 }
934 } else {
935 conn_class
936 .call((), Some(conn_dict))
937 .unwrap_or_else(|_| conn_dict.clone().into_any())
938 }
939 } else {
940 conn_dict.clone().into_any()
941 }
942 } else {
943 if let Ok(model_validate) = conn_class.getattr("model_validate") {
944 if let Ok(conn_obj) = model_validate.call1((payload_dict,)) {
945 conn_obj
946 } else {
947 conn_class
948 .call((), Some(payload_dict))
949 .unwrap_or_else(|_| payload_dict.clone().into_any())
950 }
951 } else {
952 conn_class
953 .call((), Some(payload_dict))
954 .unwrap_or_else(|_| payload_dict.clone().into_any())
955 }
956 }
957 } else {
958 if let Ok(model_validate) = conn_class.getattr("model_validate") {
959 if let Ok(conn_obj) = model_validate.call1((payload_dict,)) {
960 conn_obj
961 } else {
962 conn_class
963 .call((), Some(payload_dict))
964 .unwrap_or_else(|_| payload_dict.clone().into_any())
965 }
966 } else {
967 conn_class
968 .call((), Some(payload_dict))
969 .unwrap_or_else(|_| payload_dict.clone().into_any())
970 }
971 }
972 } else {
973 payload_dict.clone().into_any()
974 };
975
976 let message_obj = if param_count >= 3 {
977 if let Some(msg_class) = message_class {
978 if let Ok(message_dict) = payload_dict.get_item("message") {
979 if let Some(msg_dict) = message_dict {
980 if let Ok(msg_dict) = msg_dict.downcast::<PyDict>() {
981 if let Ok(model_validate) = msg_class.getattr("model_validate") {
982 if let Ok(msg_obj) = model_validate.call1((msg_dict,)) {
983 Some(msg_obj)
984 } else {
985 Some(
986 msg_class
987 .call((), Some(msg_dict))
988 .unwrap_or_else(|_| msg_dict.clone().into_any()),
989 )
990 }
991 } else {
992 Some(
993 msg_class
994 .call((), Some(msg_dict))
995 .unwrap_or_else(|_| msg_dict.clone().into_any()),
996 )
997 }
998 } else {
999 Some(msg_dict.clone().into_any())
1000 }
1001 } else {
1002 None
1003 }
1004 } else {
1005 None
1006 }
1007 } else {
1008 None
1009 }
1010 } else {
1011 None
1012 };
1013
1014 if param_count == 3 {
1015 if let Some(msg) = message_obj {
1016 handler_fn.call1((msg, connection_obj, state_obj))
1017 } else {
1018 handler_fn.call1((payload_dict.clone().into_any(), connection_obj, state_obj))
1019 }
1020 } else if param_count == 2 {
1021 handler_fn.call1((connection_obj, state_obj))
1022 } else {
1023 handler_fn.call1((connection_obj,))
1024 }
1025 }
1026
1027 fn extract_function_name(handler_name: &str) -> String {
1028 if handler_name.chars().any(|c| c.is_uppercase()) {
1029 let snake = to_snake_case(handler_name);
1030 format!("handle_{}", snake)
1031 } else {
1032 format!("handle_{}", handler_name.to_string())
1033 }
1034 }
1035
1036 pub async fn reload_module(&self, module_name: &str) -> Result<()> {
1037 let mut modules = self.modules.write().await;
1038 modules.remove(module_name);
1039 info!("Reloaded Python module: {}", module_name);
1040 Ok(())
1041 }
1042}
1043
1044impl Default for PythonRuntime {
1045 fn default() -> Self {
1046 Self::new().expect("Failed to initialize Python runtime")
1047 }
1048}
1049
/// Converts a CamelCase identifier to snake_case.
///
/// Each uppercase character is lowercased and, unless it is the first character, preceded
/// by `_`, so consecutive capitals are split (`ABC` -> `a_b_c`). Other characters are
/// copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len());
    for (position, ch) in s.chars().enumerate() {
        if !ch.is_uppercase() {
            snake.push(ch);
            continue;
        }
        if position > 0 {
            snake.push('_');
        }
        // Only the first char of the lowercase mapping is kept.
        snake.extend(ch.to_lowercase().take(1));
    }
    snake
}
1064
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_name_extraction() {
        assert_eq!(
            PythonRuntime::extract_function_name("send_welcome_email"),
            "handle_send_welcome_email"
        );
        assert_eq!(
            PythonRuntime::extract_function_name("CreateUser"),
            "handle_create_user"
        );
        assert_eq!(PythonRuntime::extract_function_name("hello"), "handle_hello");
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("CreateUser"), "create_user");
        assert_eq!(to_snake_case("UserCreated"), "user_created");
        assert_eq!(to_snake_case("already_snake"), "already_snake");
        assert_eq!(to_snake_case(""), "");
    }

    #[test]
    fn test_request_class_name() {
        assert_eq!(
            PythonRuntime::handler_name_to_request_class("create_user"),
            "CreateUserRequest"
        );
        assert_eq!(
            PythonRuntime::handler_name_to_request_class("CreateUser"),
            "CreateUserRequest"
        );
    }

    #[test]
    fn test_primitive_types() {
        assert!(PythonRuntime::is_primitive_type("String"));
        assert!(PythonRuntime::is_primitive_type("DateTime"));
        assert!(!PythonRuntime::is_primitive_type("UserCreated"));
        assert!(!PythonRuntime::is_primitive_type("string"));
    }
}