1use crate::error::Result;
2use crate::templates;
3use rohas_parser::{Api, Event, FieldType, Model, Schema, WebSocket};
4use std::fs;
5use std::path::Path;
6
7pub fn generate_models(schema: &Schema, output_dir: &Path) -> Result<()> {
8 let models_dir = output_dir.join("generated/models");
9
10 for model in &schema.models {
11 let content = generate_model_content(model);
12 let file_name = format!("{}.py", templates::to_snake_case(&model.name));
13 fs::write(models_dir.join(file_name), content)?;
14 }
15
16 Ok(())
17}
18
19fn generate_model_content(model: &Model) -> String {
20 let mut content = String::new();
21
22 content.push_str("from pydantic import BaseModel\n");
23 content.push_str("from typing import Optional\n");
24 content.push_str("from datetime import datetime\n\n");
25
26 content.push_str(&format!("class {}(BaseModel):\n", model.name));
27
28 for field in &model.fields {
29 let py_type = field.field_type.to_python();
30 let type_hint = if field.optional {
31 format!("Optional[{}]", py_type)
32 } else {
33 py_type
34 };
35 content.push_str(&format!(" {}: {}\n", field.name, type_hint));
36 }
37
38 if model.fields.is_empty() {
39 content.push_str(" pass\n");
40 }
41
42 content.push_str("\n class Config:\n");
43 content.push_str(" from_attributes = True\n");
44
45 content
46}
47
48pub fn generate_dtos(schema: &Schema, output_dir: &Path) -> Result<()> {
49 let dto_dir = output_dir.join("generated/dto");
50
51 for input in &schema.inputs {
52 let content = generate_model_content(&rohas_parser::Model {
53 name: input.name.clone(),
54 fields: input.fields.clone(),
55 attributes: vec![],
56 });
57 let file_name = format!("{}.py", templates::to_snake_case(&input.name));
58 fs::write(dto_dir.join(file_name), content)?;
59 }
60
61 Ok(())
62}
63
64pub fn generate_apis(schema: &Schema, output_dir: &Path) -> Result<()> {
65 let api_dir = output_dir.join("generated/api");
66
67 for api in &schema.apis {
68 let content = generate_api_content(api);
69 let file_name = format!("{}.py", templates::to_snake_case(&api.name));
70 fs::write(api_dir.join(file_name), content)?;
71 }
72
73 let handlers_dir = output_dir.join("handlers/api");
74 for api in &schema.apis {
75 let file_name = format!("{}.py", templates::to_snake_case(&api.name));
76 let handler_path = handlers_dir.join(&file_name);
77
78 if !handler_path.exists() {
79 let content = generate_api_handler_stub(api);
80 fs::write(handler_path, content)?;
81 }
82 }
83
84 Ok(())
85}
86
/// Collects the names of `{param}` placeholders in a route path, in order.
///
/// Empty placeholders (`{}`) are skipped. An unterminated trailing `{name` is
/// dropped. A `{` seen while already inside a placeholder restarts it, so
/// `/a/{x{y}` yields `["y"]`. A stray `}` outside a placeholder is ignored.
fn extract_path_params(path: &str) -> Vec<String> {
    let mut found = Vec::new();
    // `Some(buffer)` while inside a `{...}` placeholder, `None` outside one.
    let mut open: Option<String> = None;

    for c in path.chars() {
        match c {
            '{' => open = Some(String::new()),
            '}' => {
                if let Some(name) = open.take().filter(|n| !n.is_empty()) {
                    found.push(name);
                }
            }
            _ => {
                if let Some(buffer) = open.as_mut() {
                    buffer.push(c);
                }
            }
        }
    }

    found
}
116
117fn generate_api_content(api: &Api) -> String {
118 let mut content = String::new();
119
120 content.push_str("from pydantic import BaseModel\n");
121 content.push_str("from typing import Callable, Awaitable, Dict, Optional\n");
122
123 let response_field_type = FieldType::from_str(&api.response);
124 let response_py_type = response_field_type.to_python();
125
126 let is_custom_type = matches!(response_field_type, FieldType::Custom(_));
127 if is_custom_type {
128 content.push_str(&format!(
129 "from ..models.{} import {}\n",
130 templates::to_snake_case(&api.response),
131 api.response
132 ));
133 }
134
135 if let Some(body) = &api.body {
136 content.push_str(&format!(
137 "from ..dto.{} import {}\n",
138 templates::to_snake_case(body),
139 body
140 ));
141 }
142
143 let path_params = extract_path_params(&api.path);
144
145 content.push_str(&format!("\nclass {}Request(BaseModel):\n", api.name));
146
147 for param in &path_params {
148 content.push_str(&format!(" {}: str\n", param));
149 }
150
151 if let Some(body) = &api.body {
152 content.push_str(&format!(" body: {}\n", body));
153 }
154
155 content.push_str(" query_params: Dict[str, str] = {}\n");
156
157 if path_params.is_empty() && api.body.is_none() {
158 }
160
161 content.push_str("\n class Config:\n");
162 content.push_str(" from_attributes = True\n");
163
164 content.push_str(&format!("\nclass {}Response(BaseModel):\n", api.name));
165 content.push_str(&format!(" data: {}\n", response_py_type));
166
167 content.push_str("\n class Config:\n");
168 content.push_str(" from_attributes = True\n");
169
170 content.push_str(&format!(
171 "\n{}Handler = Callable[[{}Request], Awaitable[{}Response]]\n",
172 api.name, api.name, api.name
173 ));
174
175 content
176}
177
178fn generate_api_handler_stub(api: &Api) -> String {
179 let mut content = String::new();
180
181 content.push_str(&format!(
182 "from generated.api.{} import {}Request, {}Response\n",
183 templates::to_snake_case(&api.name),
184 api.name,
185 api.name
186 ));
187 content.push_str("from generated.state import State\n\n");
188
189 content.push_str(&format!(
190 "async def handle_{}(req: {}Request, state: State) -> {}Response:\n",
191 templates::to_snake_case(&api.name),
192 api.name,
193 api.name
194 ));
195 content.push_str(" # TODO: Implement handler logic\n");
196 content.push_str(" # For auto-triggers (defined in schema triggers): use state.set_payload('EventName', {...})\n");
197 content.push_str(" # For manual triggers: use state.trigger_event('EventName', {...})\n");
198 content.push_str(" raise NotImplementedError('Handler not implemented')\n");
199
200 content
201}
202
203pub fn generate_events(schema: &Schema, output_dir: &Path) -> Result<()> {
204 let events_dir = output_dir.join("generated/events");
205
206 for event in &schema.events {
207 let content = generate_event_content(event);
208 let file_name = format!("{}.py", templates::to_snake_case(&event.name));
209 fs::write(events_dir.join(file_name), content)?;
210 }
211
212 let handlers_dir = output_dir.join("handlers/events");
213 for event in &schema.events {
214 for handler in &event.handlers {
215 let file_name = format!("{}.py", handler);
216 let handler_path = handlers_dir.join(&file_name);
217
218 if !handler_path.exists() {
219 let content = generate_event_handler_stub(event, handler);
220 fs::write(handler_path, content)?;
221 }
222 }
223 }
224
225 Ok(())
226}
227
228fn generate_event_content(event: &Event) -> String {
229 let mut content = String::new();
230
231 content.push_str("from pydantic import BaseModel\n");
232 content.push_str("from datetime import datetime\n");
233 content.push_str("from typing import Callable, Awaitable\n");
234
235 let payload_field_type = FieldType::from_str(&event.payload);
236 let payload_py_type = payload_field_type.to_python();
237
238 let is_custom_type = matches!(payload_field_type, FieldType::Custom(_));
239 if is_custom_type {
240 content.push_str(&format!(
241 "from ..models.{} import {}\n",
242 templates::to_snake_case(&event.payload),
243 event.payload
244 ));
245 }
246
247 content.push_str(&format!("\nclass {}(BaseModel):\n", event.name));
248 content.push_str(&format!(" payload: {}\n", payload_py_type));
249 content.push_str(" timestamp: datetime\n\n");
250
251 content.push_str(" class Config:\n");
252 content.push_str(" from_attributes = True\n\n");
253
254 content.push_str(&format!(
255 "{}Handler = Callable[[{}], Awaitable[None]]\n",
256 event.name, event.name
257 ));
258
259 content
260}
261
262fn generate_event_handler_stub(event: &Event, handler_name: &str) -> String {
263 let mut content = String::new();
264
265 content.push_str(&format!(
266 "from generated.events.{} import {}\n\n",
267 templates::to_snake_case(&event.name),
268 event.name
269 ));
270
271 content.push_str(&format!(
272 "async def {}(event: {}) -> None:\n",
273 handler_name, event.name
274 ));
275 content.push_str(" # TODO: Implement event handler\n");
276 content.push_str(&format!(" print(f'Handling event: {{event}}')\n"));
277
278 content
279}
280
281pub fn generate_crons(schema: &Schema, output_dir: &Path) -> Result<()> {
282 let handlers_dir = output_dir.join("handlers/cron");
283
284 for cron in &schema.crons {
285 let file_name = format!("{}.py", templates::to_snake_case(&cron.name));
286 let handler_path = handlers_dir.join(&file_name);
287
288 if !handler_path.exists() {
289 let content = format!(
290 "async def handle_{}() -> None:\n # TODO: Implement cron job\n print('Running cron: {}')\n",
291 templates::to_snake_case(&cron.name),
292 cron.name
293 );
294 fs::write(handler_path, content)?;
295 }
296 }
297
298 Ok(())
299}
300
301pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> {
302 let ws_dir = output_dir.join("generated/websockets");
303
304 for ws in &schema.websockets {
305 let content = generate_websocket_content(ws);
306 let file_name = format!("{}.py", templates::to_snake_case(&ws.name));
307 fs::write(ws_dir.join(file_name), content)?;
308 }
309
310 let handlers_dir = output_dir.join("handlers/websockets");
311 for ws in &schema.websockets {
312 if !ws.on_connect.is_empty() {
313 for handler in &ws.on_connect {
314 let file_name = format!("{}.py", handler);
315 let handler_path = handlers_dir.join(&file_name);
316 if !handler_path.exists() {
317 let content = generate_websocket_handler_stub(ws, "onConnect", handler);
318 fs::write(handler_path, content)?;
319 }
320 }
321 }
322 if !ws.on_message.is_empty() {
323 for handler in &ws.on_message {
324 let file_name = format!("{}.py", handler);
325 let handler_path = handlers_dir.join(&file_name);
326 if !handler_path.exists() {
327 let content = generate_websocket_handler_stub(ws, "onMessage", handler);
328 fs::write(handler_path, content)?;
329 }
330 }
331 }
332 if !ws.on_disconnect.is_empty() {
333 for handler in &ws.on_disconnect {
334 let file_name = format!("{}.py", handler);
335 let handler_path = handlers_dir.join(&file_name);
336 if !handler_path.exists() {
337 let content = generate_websocket_handler_stub(ws, "onDisconnect", handler);
338 fs::write(handler_path, content)?;
339 }
340 }
341 }
342 }
343
344 Ok(())
345}
346
347pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> {
348 use std::collections::HashSet;
349
350 let mut middleware_names = HashSet::new();
351
352 for api in &schema.apis {
353 for middleware in &api.middlewares {
354 middleware_names.insert(middleware.clone());
355 }
356 }
357
358 for ws in &schema.websockets {
359 for middleware in &ws.middlewares {
360 middleware_names.insert(middleware.clone());
361 }
362 }
363
364 if middleware_names.is_empty() {
365 return Ok(());
366 }
367
368 let middlewares_dir = output_dir.join("middlewares");
369 for middleware_name in middleware_names {
370 let file_name = format!("{}.py", templates::to_snake_case(&middleware_name));
371 let middleware_path = middlewares_dir.join(&file_name);
372
373 if !middleware_path.exists() {
374 let content = generate_middleware_stub(&middleware_name);
375 fs::write(middleware_path, content)?;
376 }
377 }
378
379 Ok(())
380}
381
382fn generate_middleware_stub(middleware_name: &str) -> String {
383 let mut content = String::new();
384
385 content.push_str("from typing import Dict, Any, Optional\n");
386 content.push_str("from generated.state import State\n\n");
387
388 content.push_str(&format!(
389 "async def {}_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]:\n",
390 templates::to_snake_case(middleware_name)
391 ));
392 content.push_str(" \"\"\"\n");
393 content.push_str(&format!(" Middleware function for {}.\n\n", middleware_name));
394 content.push_str(" Args:\n");
395 content.push_str(" context: Request context containing:\n");
396 content.push_str(" - payload: Request payload (for APIs)\n");
397 content.push_str(" - query_params: Query parameters (for APIs)\n");
398 content.push_str(" - connection: WebSocket connection info (for WebSockets)\n");
399 content.push_str(" - websocket_name: WebSocket name (for WebSockets)\n");
400 content.push_str(" - api_name: API name (for APIs)\n");
401 content.push_str(" - trace_id: Trace ID\n");
402 content.push_str(" state: State object for logging and triggering events\n\n");
403 content.push_str(" Returns:\n");
404 content.push_str(" Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys,\n");
405 content.push_str(" or None to pass through unchanged. Return a dict with 'error' key to reject the request.\n\n");
406 content.push_str(" To reject the request, raise an exception \n");
407 content.push_str(" \"\"\"\n");
408 content.push_str(" # TODO: Implement middleware logic\n");
409 content.push_str(" # Example: Validate authentication\n");
410 content.push_str(" # Example: Rate limiting\n");
411 content.push_str(" # Example: Logging\n");
412 content.push_str(" # Example: Modify payload/query_params\n");
413 content.push_str(" # \n");
414 content.push_str(" # To modify the request:\n");
415 content.push_str(" # return {\n");
416 content.push_str(" # 'payload': modified_payload,\n");
417 content.push_str(" # 'query_params': modified_query_params\n");
418 content.push_str(" # }\n");
419 content.push_str(" # \n");
420 content.push_str(" # To reject the request:\n");
421 content.push_str(" # raise Exception('Access denied')\n");
422 content.push_str(" \n");
423 content.push_str(" # Pass through unchanged\n");
424 content.push_str(" return None\n");
425
426 content
427}
428
429fn generate_websocket_content(ws: &WebSocket) -> String {
430 let mut content = String::new();
431
432 content.push_str("from pydantic import BaseModel\n");
433 content.push_str("from typing import Dict, Any, Optional\n");
434 content.push_str("from datetime import datetime\n\n");
435
436 if let Some(message_type) = &ws.message {
437 let message_field_type = FieldType::from_str(message_type);
438 let is_custom_type = matches!(message_field_type, FieldType::Custom(_));
439 if is_custom_type {
440 content.push_str(&format!(
441 "from ..dto.{} import {}\n",
442 templates::to_snake_case(message_type),
443 message_type
444 ));
445 }
446 }
447
448 content.push_str(&format!("class {}Message(BaseModel):\n", ws.name));
449 if let Some(message_type) = &ws.message {
450 let message_field_type = FieldType::from_str(message_type);
451 let py_type = message_field_type.to_python();
452 content.push_str(&format!(" data: {}\n", py_type));
453 } else {
454 content.push_str(" data: Dict[str, Any]\n");
455 }
456 content.push_str(" timestamp: datetime\n\n");
457 content.push_str(" class Config:\n");
458 content.push_str(" from_attributes = True\n\n");
459
460 content.push_str(&format!("class {}Connection(BaseModel):\n", ws.name));
461 content.push_str(" connection_id: str\n");
462 content.push_str(" path: str\n");
463 content.push_str(" connected_at: datetime\n\n");
464 content.push_str(" class Config:\n");
465 content.push_str(" from_attributes = True\n");
466
467 content
468}
469
470fn generate_websocket_handler_stub(
471 ws: &WebSocket,
472 handler_type: &str,
473 handler_name: &str,
474) -> String {
475 let mut content = String::new();
476
477 content.push_str(&format!(
478 "from generated.websockets.{} import {}Message, {}Connection\n",
479 templates::to_snake_case(&ws.name),
480 ws.name,
481 ws.name
482 ));
483 content.push_str("from generated.state import State\n");
484 content.push_str("from typing import Optional\n\n");
485
486 match handler_type {
487 "onConnect" => {
488 content.push_str(&format!(
489 "async def {}(connection: {}Connection, state: State) -> Optional[{}Message]:\n",
490 handler_name, ws.name, ws.name
491 ));
492 content.push_str(" # TODO: Implement onConnect handler\n");
493 content
494 .push_str(" # Return a message to send to the client on connection, or None\n");
495 content.push_str(&format!(
496 " print(f'Client connected: {{connection.connection_id}}')\n"
497 ));
498 content.push_str(" return None\n");
499 }
500 "onMessage" => {
501 content.push_str(&format!(
502 "async def {}(message: {}Message, connection: {}Connection, state: State) -> Optional[{}Message]:\n",
503 handler_name,
504 ws.name,
505 ws.name,
506 ws.name
507 ));
508 content.push_str(" # TODO: Implement onMessage handler\n");
509 content.push_str(" # Return a message to send back to the client, or None\n");
510 content.push_str(&format!(
511 " print(f'Received message: {{message.data}}')\n"
512 ));
513 content.push_str(" # For auto-triggers (defined in schema triggers): use state.set_payload('EventName', {...})\n");
514 content.push_str(
515 " # For manual triggers: use state.trigger_event('EventName', {...})\n",
516 );
517 content.push_str(" return None\n");
518 }
519 "onDisconnect" => {
520 content.push_str(&format!(
521 "async def {}(connection: {}Connection, state: State) -> None:\n",
522 handler_name, ws.name
523 ));
524 content.push_str(" # TODO: Implement onDisconnect handler\n");
525 content.push_str(&format!(
526 " print(f'Client disconnected: {{connection.connection_id}}')\n"
527 ));
528 }
529 _ => {}
530 }
531
532 content
533}
534
535pub fn generate_state(output_dir: &Path) -> Result<()> {
536 let generated_dir = output_dir.join("generated");
537 let content = r#"from typing import Any, Dict, List, Optional
538from pydantic import BaseModel
539
540
541class TriggeredEvent(BaseModel):
542 event_name: str
543 payload: Dict[str, Any]
544
545
546class Logger:
547 """Logger for handlers to emit structured logs."""
548
549 def __init__(self, handler_name: str, log_fn: Any):
550 self._handler_name = handler_name
551 self._log_fn = log_fn
552
553 def info(self, message: str, **kwargs: Any) -> None:
554 """Log an info message.
555
556 Args:
557 message: Log message
558 **kwargs: Additional fields to include in the log
559 """
560 if self._log_fn:
561 self._log_fn("info", self._handler_name, message, kwargs)
562
563 def error(self, message: str, **kwargs: Any) -> None:
564 """Log an error message.
565
566 Args:
567 message: Log message
568 **kwargs: Additional fields to include in the log
569 """
570 if self._log_fn:
571 self._log_fn("error", self._handler_name, message, kwargs)
572
573 def warning(self, message: str, **kwargs: Any) -> None:
574 """Log a warning message.
575
576 Args:
577 message: Log message
578 **kwargs: Additional fields to include in the log
579 """
580 if self._log_fn:
581 self._log_fn("warn", self._handler_name, message, kwargs)
582
583 def warn(self, message: str, **kwargs: Any) -> None:
584 """Log a warning message (alias for warning).
585
586 Args:
587 message: Log message
588 **kwargs: Additional fields to include in the log
589 """
590 self.warning(message, **kwargs)
591
592 def debug(self, message: str, **kwargs: Any) -> None:
593 """Log a debug message.
594
595 Args:
596 message: Log message
597 **kwargs: Additional fields to include in the log
598 """
599 if self._log_fn:
600 self._log_fn("debug", self._handler_name, message, kwargs)
601
602 def trace(self, message: str, **kwargs: Any) -> None:
603 """Log a trace message.
604
605 Args:
606 message: Log message
607 **kwargs: Additional fields to include in the log
608 """
609 if self._log_fn:
610 self._log_fn("trace", self._handler_name, message, kwargs)
611
612
613class State:
614 """Context object for handlers to trigger events and access runtime state."""
615
616 def __init__(self, handler_name: Optional[str] = None, log_fn: Optional[Any] = None):
617 self._triggers: List[TriggeredEvent] = []
618 self._auto_trigger_payloads: Dict[str, Dict[str, Any]] = {}
619 self.logger = Logger(handler_name or "unknown", log_fn)
620
621 def trigger_event(self, event_name: str, payload: Dict[str, Any]) -> None:
622 """Manually trigger an event with the given payload.
623
624 Use this for events that are NOT defined in the schema's triggers list.
625
626 Args:
627 event_name: Name of the event to trigger
628 payload: Event payload data (will be serialized to JSON)
629 """
630 self._triggers.append(TriggeredEvent(
631 event_name=event_name,
632 payload=payload
633 ))
634
635 def set_payload(self, event_name: str, payload: Dict[str, Any]) -> None:
636 """Set the payload for an auto-triggered event.
637
638 Use this for events that ARE defined in the schema's triggers list.
639 The event will be automatically triggered after the handler completes,
640 using the payload you set here.
641
642 Args:
643 event_name: Name of the event (must match a trigger in schema)
644 payload: Event payload data (will be serialized to JSON)
645 """
646 self._auto_trigger_payloads[event_name] = payload
647
648 def get_triggers(self) -> List[TriggeredEvent]:
649 """Get all manually triggered events. Used internally by the runtime."""
650 return self._triggers.copy()
651
652 def get_auto_trigger_payload(self, event_name: str) -> Optional[Dict[str, Any]]:
653 """Get payload for an auto-triggered event. Used internally by the runtime."""
654 return self._auto_trigger_payloads.get(event_name)
655
656 def get_all_auto_trigger_payloads(self) -> Dict[str, Dict[str, Any]]:
657 """Get all auto-trigger payloads. Used internally by the runtime."""
658 return self._auto_trigger_payloads.copy()
659"#;
660
661 fs::write(generated_dir.join("state.py"), content)?;
662 Ok(())
663}
664
665pub fn generate_init(schema: &Schema, output_dir: &Path) -> Result<()> {
666 let generated_dir = output_dir.join("generated");
667
668 let subdirs = ["models", "dto", "api", "events", "cron", "websockets"];
669 for subdir in &subdirs {
670 fs::write(generated_dir.join(format!("{}/__init__.py", subdir)), "")?;
671 }
672
673 let mut content = String::new();
674 content.push_str("# Generated by Rohas - Do not edit\n\n");
675
676 content.push_str("from .state import State, TriggeredEvent\n");
677
678 for model in &schema.models {
679 content.push_str(&format!(
680 "from .models.{} import {}\n",
681 templates::to_snake_case(&model.name),
682 model.name
683 ));
684 }
685
686 fs::write(generated_dir.join("__init__.py"), content)?;
687
688 Ok(())
689}