Skip to main content

rohas_codegen/
python.rs

1use crate::error::Result;
2use crate::templates;
3use rohas_parser::{Api, Event, FieldType, Model, Schema, WebSocket};
4use std::fs;
5use std::path::Path;
6
7pub fn generate_models(schema: &Schema, output_dir: &Path) -> Result<()> {
8    let models_dir = output_dir.join("generated/models");
9
10    for model in &schema.models {
11        let content = generate_model_content(model);
12        let file_name = format!("{}.py", templates::to_snake_case(&model.name));
13        fs::write(models_dir.join(file_name), content)?;
14    }
15
16    Ok(())
17}
18
19fn generate_model_content(model: &Model) -> String {
20    let mut content = String::new();
21
22    content.push_str("from pydantic import BaseModel\n");
23    content.push_str("from typing import Optional\n");
24    content.push_str("from datetime import datetime\n\n");
25
26    content.push_str(&format!("class {}(BaseModel):\n", model.name));
27
28    for field in &model.fields {
29        let py_type = field.field_type.to_python();
30        let type_hint = if field.optional {
31            format!("Optional[{}]", py_type)
32        } else {
33            py_type
34        };
35        content.push_str(&format!("    {}: {}\n", field.name, type_hint));
36    }
37
38    if model.fields.is_empty() {
39        content.push_str("    pass\n");
40    }
41
42    content.push_str("\n    class Config:\n");
43    content.push_str("        from_attributes = True\n");
44
45    content
46}
47
48pub fn generate_dtos(schema: &Schema, output_dir: &Path) -> Result<()> {
49    let dto_dir = output_dir.join("generated/dto");
50
51    for input in &schema.inputs {
52        let content = generate_model_content(&rohas_parser::Model {
53            name: input.name.clone(),
54            fields: input.fields.clone(),
55            attributes: vec![],
56        });
57        let file_name = format!("{}.py", templates::to_snake_case(&input.name));
58        fs::write(dto_dir.join(file_name), content)?;
59    }
60
61    Ok(())
62}
63
64pub fn generate_apis(schema: &Schema, output_dir: &Path) -> Result<()> {
65    let api_dir = output_dir.join("generated/api");
66
67    for api in &schema.apis {
68        let content = generate_api_content(api);
69        let file_name = format!("{}.py", templates::to_snake_case(&api.name));
70        fs::write(api_dir.join(file_name), content)?;
71    }
72
73    let handlers_dir = output_dir.join("handlers/api");
74    for api in &schema.apis {
75        let file_name = format!("{}.py", templates::to_snake_case(&api.name));
76        let handler_path = handlers_dir.join(&file_name);
77
78        if !handler_path.exists() {
79            let content = generate_api_handler_stub(api);
80            fs::write(handler_path, content)?;
81        }
82    }
83
84    Ok(())
85}
86
/// Returns the names of `{placeholder}` segments in a route path, in order.
///
/// e.g. `"/test/{name}"` -> `["name"]`,
/// `"/users/{id}/posts/{postId}"` -> `["id", "postId"]`.
///
/// Empty braces (`{}`) and an unterminated `{` are ignored. If a `{` appears again
/// before the closing `}`, only the text after the last `{` is kept. A stray `}`
/// outside a placeholder is ignored.
fn extract_path_params(path: &str) -> Vec<String> {
    path.split('{')
        // Text before the first `{` can never be a parameter.
        .skip(1)
        // A placeholder is whatever precedes the first `}` in the segment.
        .filter_map(|segment| segment.split_once('}'))
        .map(|(name, _rest)| name)
        .filter(|name| !name.is_empty())
        .map(String::from)
        .collect()
}
116
117fn generate_api_content(api: &Api) -> String {
118    let mut content = String::new();
119
120    content.push_str("from pydantic import BaseModel\n");
121    content.push_str("from typing import Callable, Awaitable, Dict, Optional\n");
122
123    let response_field_type = FieldType::from_str(&api.response);
124    let response_py_type = response_field_type.to_python();
125
126    let is_custom_type = matches!(response_field_type, FieldType::Custom(_));
127    if is_custom_type {
128        content.push_str(&format!(
129            "from ..models.{} import {}\n",
130            templates::to_snake_case(&api.response),
131            api.response
132        ));
133    }
134
135    if let Some(body) = &api.body {
136        content.push_str(&format!(
137            "from ..dto.{} import {}\n",
138            templates::to_snake_case(body),
139            body
140        ));
141    }
142
143    let path_params = extract_path_params(&api.path);
144
145    content.push_str(&format!("\nclass {}Request(BaseModel):\n", api.name));
146
147    for param in &path_params {
148        content.push_str(&format!("    {}: str\n", param));
149    }
150
151    if let Some(body) = &api.body {
152        content.push_str(&format!("    body: {}\n", body));
153    }
154
155    content.push_str("    query_params: Dict[str, str] = {}\n");
156
157    if path_params.is_empty() && api.body.is_none() {
158        // We still have query_params, so no pass needed
159    }
160
161    content.push_str("\n    class Config:\n");
162    content.push_str("        from_attributes = True\n");
163
164    content.push_str(&format!("\nclass {}Response(BaseModel):\n", api.name));
165    content.push_str(&format!("    data: {}\n", response_py_type));
166
167    content.push_str("\n    class Config:\n");
168    content.push_str("        from_attributes = True\n");
169
170    content.push_str(&format!(
171        "\n{}Handler = Callable[[{}Request], Awaitable[{}Response]]\n",
172        api.name, api.name, api.name
173    ));
174
175    content
176}
177
178fn generate_api_handler_stub(api: &Api) -> String {
179    let mut content = String::new();
180
181    content.push_str(&format!(
182        "from generated.api.{} import {}Request, {}Response\n",
183        templates::to_snake_case(&api.name),
184        api.name,
185        api.name
186    ));
187    content.push_str("from generated.state import State\n\n");
188
189    content.push_str(&format!(
190        "async def handle_{}(req: {}Request, state: State) -> {}Response:\n",
191        templates::to_snake_case(&api.name),
192        api.name,
193        api.name
194    ));
195    content.push_str("    # TODO: Implement handler logic\n");
196    content.push_str("    # For auto-triggers (defined in schema triggers): use state.set_payload('EventName', {...})\n");
197    content.push_str("    # For manual triggers: use state.trigger_event('EventName', {...})\n");
198    content.push_str("    raise NotImplementedError('Handler not implemented')\n");
199
200    content
201}
202
203pub fn generate_events(schema: &Schema, output_dir: &Path) -> Result<()> {
204    let events_dir = output_dir.join("generated/events");
205
206    for event in &schema.events {
207        let content = generate_event_content(event);
208        let file_name = format!("{}.py", templates::to_snake_case(&event.name));
209        fs::write(events_dir.join(file_name), content)?;
210    }
211
212    let handlers_dir = output_dir.join("handlers/events");
213    for event in &schema.events {
214        for handler in &event.handlers {
215            let file_name = format!("{}.py", handler);
216            let handler_path = handlers_dir.join(&file_name);
217
218            if !handler_path.exists() {
219                let content = generate_event_handler_stub(event, handler);
220                fs::write(handler_path, content)?;
221            }
222        }
223    }
224
225    Ok(())
226}
227
228fn generate_event_content(event: &Event) -> String {
229    let mut content = String::new();
230
231    content.push_str("from pydantic import BaseModel\n");
232    content.push_str("from datetime import datetime\n");
233    content.push_str("from typing import Callable, Awaitable\n");
234
235    let payload_field_type = FieldType::from_str(&event.payload);
236    let payload_py_type = payload_field_type.to_python();
237
238    let is_custom_type = matches!(payload_field_type, FieldType::Custom(_));
239    if is_custom_type {
240        content.push_str(&format!(
241            "from ..models.{} import {}\n",
242            templates::to_snake_case(&event.payload),
243            event.payload
244        ));
245    }
246
247    content.push_str(&format!("\nclass {}(BaseModel):\n", event.name));
248    content.push_str(&format!("    payload: {}\n", payload_py_type));
249    content.push_str("    timestamp: datetime\n\n");
250
251    content.push_str("    class Config:\n");
252    content.push_str("        from_attributes = True\n\n");
253
254    content.push_str(&format!(
255        "{}Handler = Callable[[{}], Awaitable[None]]\n",
256        event.name, event.name
257    ));
258
259    content
260}
261
262fn generate_event_handler_stub(event: &Event, handler_name: &str) -> String {
263    let mut content = String::new();
264
265    content.push_str(&format!(
266        "from generated.events.{} import {}\n\n",
267        templates::to_snake_case(&event.name),
268        event.name
269    ));
270
271    content.push_str(&format!(
272        "async def {}(event: {}) -> None:\n",
273        handler_name, event.name
274    ));
275    content.push_str("    # TODO: Implement event handler\n");
276    content.push_str(&format!("    print(f'Handling event: {{event}}')\n"));
277
278    content
279}
280
281pub fn generate_crons(schema: &Schema, output_dir: &Path) -> Result<()> {
282    let handlers_dir = output_dir.join("handlers/cron");
283
284    for cron in &schema.crons {
285        let file_name = format!("{}.py", templates::to_snake_case(&cron.name));
286        let handler_path = handlers_dir.join(&file_name);
287
288        if !handler_path.exists() {
289            let content = format!(
290                "async def handle_{}() -> None:\n    # TODO: Implement cron job\n    print('Running cron: {}')\n",
291                templates::to_snake_case(&cron.name),
292                cron.name
293            );
294            fs::write(handler_path, content)?;
295        }
296    }
297
298    Ok(())
299}
300
301pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> {
302    let ws_dir = output_dir.join("generated/websockets");
303
304    for ws in &schema.websockets {
305        let content = generate_websocket_content(ws);
306        let file_name = format!("{}.py", templates::to_snake_case(&ws.name));
307        fs::write(ws_dir.join(file_name), content)?;
308    }
309
310    let handlers_dir = output_dir.join("handlers/websockets");
311    for ws in &schema.websockets {
312        if !ws.on_connect.is_empty() {
313            for handler in &ws.on_connect {
314                let file_name = format!("{}.py", handler);
315                let handler_path = handlers_dir.join(&file_name);
316                if !handler_path.exists() {
317                    let content = generate_websocket_handler_stub(ws, "onConnect", handler);
318                    fs::write(handler_path, content)?;
319                }
320            }
321        }
322        if !ws.on_message.is_empty() {
323            for handler in &ws.on_message {
324                let file_name = format!("{}.py", handler);
325                let handler_path = handlers_dir.join(&file_name);
326                if !handler_path.exists() {
327                    let content = generate_websocket_handler_stub(ws, "onMessage", handler);
328                    fs::write(handler_path, content)?;
329                }
330            }
331        }
332        if !ws.on_disconnect.is_empty() {
333            for handler in &ws.on_disconnect {
334                let file_name = format!("{}.py", handler);
335                let handler_path = handlers_dir.join(&file_name);
336                if !handler_path.exists() {
337                    let content = generate_websocket_handler_stub(ws, "onDisconnect", handler);
338                    fs::write(handler_path, content)?;
339                }
340            }
341        }
342    }
343
344    Ok(())
345}
346
347pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> {
348    use std::collections::HashSet;
349   
350    let mut middleware_names = HashSet::new();
351    
352    for api in &schema.apis {
353        for middleware in &api.middlewares {
354            middleware_names.insert(middleware.clone());
355        }
356    }
357    
358    for ws in &schema.websockets {
359        for middleware in &ws.middlewares {
360            middleware_names.insert(middleware.clone());
361        }
362    }
363    
364    if middleware_names.is_empty() {
365        return Ok(());
366    }
367    
368    let middlewares_dir = output_dir.join("middlewares");
369    for middleware_name in middleware_names {
370        let file_name = format!("{}.py", templates::to_snake_case(&middleware_name));
371        let middleware_path = middlewares_dir.join(&file_name);
372        
373        if !middleware_path.exists() {
374            let content = generate_middleware_stub(&middleware_name);
375            fs::write(middleware_path, content)?;
376        }
377    }
378    
379    Ok(())
380}
381
382fn generate_middleware_stub(middleware_name: &str) -> String {
383    let mut content = String::new();
384    
385    content.push_str("from typing import Dict, Any, Optional\n");
386    content.push_str("from generated.state import State\n\n");
387    
388    content.push_str(&format!(
389        "async def {}_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]:\n",
390        templates::to_snake_case(middleware_name)
391    ));
392    content.push_str("    \"\"\"\n");
393    content.push_str(&format!("    Middleware function for {}.\n\n", middleware_name));
394    content.push_str("    Args:\n");
395    content.push_str("        context: Request context containing:\n");
396    content.push_str("            - payload: Request payload (for APIs)\n");
397    content.push_str("            - query_params: Query parameters (for APIs)\n");
398    content.push_str("            - connection: WebSocket connection info (for WebSockets)\n");
399    content.push_str("            - websocket_name: WebSocket name (for WebSockets)\n");
400    content.push_str("            - api_name: API name (for APIs)\n");
401    content.push_str("            - trace_id: Trace ID\n");
402    content.push_str("        state: State object for logging and triggering events\n\n");
403    content.push_str("    Returns:\n");
404    content.push_str("        Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys,\n");
405    content.push_str("        or None to pass through unchanged. Return a dict with 'error' key to reject the request.\n\n");
406    content.push_str("    To reject the request, raise an exception \n");
407    content.push_str("    \"\"\"\n");
408    content.push_str("    # TODO: Implement middleware logic\n");
409    content.push_str("    # Example: Validate authentication\n");
410    content.push_str("    # Example: Rate limiting\n");
411    content.push_str("    # Example: Logging\n");
412    content.push_str("    # Example: Modify payload/query_params\n");
413    content.push_str("    # \n");
414    content.push_str("    # To modify the request:\n");
415    content.push_str("    # return {\n");
416    content.push_str("    #     'payload': modified_payload,\n");
417    content.push_str("    #     'query_params': modified_query_params\n");
418    content.push_str("    # }\n");
419    content.push_str("    # \n");
420    content.push_str("    # To reject the request:\n");
421    content.push_str("    # raise Exception('Access denied')\n");
422    content.push_str("    \n");
423    content.push_str("    # Pass through unchanged\n");
424    content.push_str("    return None\n");
425    
426    content
427}
428
429fn generate_websocket_content(ws: &WebSocket) -> String {
430    let mut content = String::new();
431
432    content.push_str("from pydantic import BaseModel\n");
433    content.push_str("from typing import Dict, Any, Optional\n");
434    content.push_str("from datetime import datetime\n\n");
435
436    if let Some(message_type) = &ws.message {
437        let message_field_type = FieldType::from_str(message_type);
438        let is_custom_type = matches!(message_field_type, FieldType::Custom(_));
439        if is_custom_type {
440            content.push_str(&format!(
441                "from ..dto.{} import {}\n",
442                templates::to_snake_case(message_type),
443                message_type
444            ));
445        }
446    }
447
448    content.push_str(&format!("class {}Message(BaseModel):\n", ws.name));
449    if let Some(message_type) = &ws.message {
450        let message_field_type = FieldType::from_str(message_type);
451        let py_type = message_field_type.to_python();
452        content.push_str(&format!("    data: {}\n", py_type));
453    } else {
454        content.push_str("    data: Dict[str, Any]\n");
455    }
456    content.push_str("    timestamp: datetime\n\n");
457    content.push_str("    class Config:\n");
458    content.push_str("        from_attributes = True\n\n");
459
460    content.push_str(&format!("class {}Connection(BaseModel):\n", ws.name));
461    content.push_str("    connection_id: str\n");
462    content.push_str("    path: str\n");
463    content.push_str("    connected_at: datetime\n\n");
464    content.push_str("    class Config:\n");
465    content.push_str("        from_attributes = True\n");
466
467    content
468}
469
470fn generate_websocket_handler_stub(
471    ws: &WebSocket,
472    handler_type: &str,
473    handler_name: &str,
474) -> String {
475    let mut content = String::new();
476
477    content.push_str(&format!(
478        "from generated.websockets.{} import {}Message, {}Connection\n",
479        templates::to_snake_case(&ws.name),
480        ws.name,
481        ws.name
482    ));
483    content.push_str("from generated.state import State\n");
484    content.push_str("from typing import Optional\n\n");
485
486    match handler_type {
487        "onConnect" => {
488            content.push_str(&format!(
489                "async def {}(connection: {}Connection, state: State) -> Optional[{}Message]:\n",
490                handler_name, ws.name, ws.name
491            ));
492            content.push_str("    # TODO: Implement onConnect handler\n");
493            content
494                .push_str("    # Return a message to send to the client on connection, or None\n");
495            content.push_str(&format!(
496                "    print(f'Client connected: {{connection.connection_id}}')\n"
497            ));
498            content.push_str("    return None\n");
499        }
500        "onMessage" => {
501            content.push_str(&format!(
502                "async def {}(message: {}Message, connection: {}Connection, state: State) -> Optional[{}Message]:\n",
503                handler_name,
504                ws.name,
505                ws.name,
506                ws.name
507            ));
508            content.push_str("    # TODO: Implement onMessage handler\n");
509            content.push_str("    # Return a message to send back to the client, or None\n");
510            content.push_str(&format!(
511                "    print(f'Received message: {{message.data}}')\n"
512            ));
513            content.push_str("    # For auto-triggers (defined in schema triggers): use state.set_payload('EventName', {...})\n");
514            content.push_str(
515                "    # For manual triggers: use state.trigger_event('EventName', {...})\n",
516            );
517            content.push_str("    return None\n");
518        }
519        "onDisconnect" => {
520            content.push_str(&format!(
521                "async def {}(connection: {}Connection, state: State) -> None:\n",
522                handler_name, ws.name
523            ));
524            content.push_str("    # TODO: Implement onDisconnect handler\n");
525            content.push_str(&format!(
526                "    print(f'Client disconnected: {{connection.connection_id}}')\n"
527            ));
528        }
529        _ => {}
530    }
531
532    content
533}
534
535pub fn generate_state(output_dir: &Path) -> Result<()> {
536    let generated_dir = output_dir.join("generated");
537    let content = r#"from typing import Any, Dict, List, Optional
538from pydantic import BaseModel
539
540
541class TriggeredEvent(BaseModel):
542    event_name: str
543    payload: Dict[str, Any]
544
545
546class Logger:
547    """Logger for handlers to emit structured logs."""
548    
549    def __init__(self, handler_name: str, log_fn: Any):
550        self._handler_name = handler_name
551        self._log_fn = log_fn
552    
553    def info(self, message: str, **kwargs: Any) -> None:
554        """Log an info message.
555        
556        Args:
557            message: Log message
558            **kwargs: Additional fields to include in the log
559        """
560        if self._log_fn:
561            self._log_fn("info", self._handler_name, message, kwargs)
562    
563    def error(self, message: str, **kwargs: Any) -> None:
564        """Log an error message.
565        
566        Args:
567            message: Log message
568            **kwargs: Additional fields to include in the log
569        """
570        if self._log_fn:
571            self._log_fn("error", self._handler_name, message, kwargs)
572    
573    def warning(self, message: str, **kwargs: Any) -> None:
574        """Log a warning message.
575        
576        Args:
577            message: Log message
578            **kwargs: Additional fields to include in the log
579        """
580        if self._log_fn:
581            self._log_fn("warn", self._handler_name, message, kwargs)
582    
583    def warn(self, message: str, **kwargs: Any) -> None:
584        """Log a warning message (alias for warning).
585        
586        Args:
587            message: Log message
588            **kwargs: Additional fields to include in the log
589        """
590        self.warning(message, **kwargs)
591    
592    def debug(self, message: str, **kwargs: Any) -> None:
593        """Log a debug message.
594        
595        Args:
596            message: Log message
597            **kwargs: Additional fields to include in the log
598        """
599        if self._log_fn:
600            self._log_fn("debug", self._handler_name, message, kwargs)
601    
602    def trace(self, message: str, **kwargs: Any) -> None:
603        """Log a trace message.
604        
605        Args:
606            message: Log message
607            **kwargs: Additional fields to include in the log
608        """
609        if self._log_fn:
610            self._log_fn("trace", self._handler_name, message, kwargs)
611
612
613class State:
614    """Context object for handlers to trigger events and access runtime state."""
615    
616    def __init__(self, handler_name: Optional[str] = None, log_fn: Optional[Any] = None):
617        self._triggers: List[TriggeredEvent] = []
618        self._auto_trigger_payloads: Dict[str, Dict[str, Any]] = {}
619        self.logger = Logger(handler_name or "unknown", log_fn)
620    
621    def trigger_event(self, event_name: str, payload: Dict[str, Any]) -> None:
622        """Manually trigger an event with the given payload.
623        
624        Use this for events that are NOT defined in the schema's triggers list.
625        
626        Args:
627            event_name: Name of the event to trigger
628            payload: Event payload data (will be serialized to JSON)
629        """
630        self._triggers.append(TriggeredEvent(
631            event_name=event_name,
632            payload=payload
633        ))
634    
635    def set_payload(self, event_name: str, payload: Dict[str, Any]) -> None:
636        """Set the payload for an auto-triggered event.
637        
638        Use this for events that ARE defined in the schema's triggers list.
639        The event will be automatically triggered after the handler completes,
640        using the payload you set here.
641        
642        Args:
643            event_name: Name of the event (must match a trigger in schema)
644            payload: Event payload data (will be serialized to JSON)
645        """
646        self._auto_trigger_payloads[event_name] = payload
647    
648    def get_triggers(self) -> List[TriggeredEvent]:
649        """Get all manually triggered events. Used internally by the runtime."""
650        return self._triggers.copy()
651    
652    def get_auto_trigger_payload(self, event_name: str) -> Optional[Dict[str, Any]]:
653        """Get payload for an auto-triggered event. Used internally by the runtime."""
654        return self._auto_trigger_payloads.get(event_name)
655    
656    def get_all_auto_trigger_payloads(self) -> Dict[str, Dict[str, Any]]:
657        """Get all auto-trigger payloads. Used internally by the runtime."""
658        return self._auto_trigger_payloads.copy()
659"#;
660
661    fs::write(generated_dir.join("state.py"), content)?;
662    Ok(())
663}
664
665pub fn generate_init(schema: &Schema, output_dir: &Path) -> Result<()> {
666    let generated_dir = output_dir.join("generated");
667
668    let subdirs = ["models", "dto", "api", "events", "cron", "websockets"];
669    for subdir in &subdirs {
670        fs::write(generated_dir.join(format!("{}/__init__.py", subdir)), "")?;
671    }
672
673    let mut content = String::new();
674    content.push_str("# Generated by Rohas - Do not edit\n\n");
675
676    content.push_str("from .state import State, TriggeredEvent\n");
677
678    for model in &schema.models {
679        content.push_str(&format!(
680            "from .models.{} import {}\n",
681            templates::to_snake_case(&model.name),
682            model.name
683        ));
684    }
685
686    fs::write(generated_dir.join("__init__.py"), content)?;
687
688    Ok(())
689}