1use super::common::*;
2use crate::context::ParsedFile;
3use pecto_core::model::*;
4
5pub fn extract(file: &ParsedFile) -> Option<Capability> {
7 let root = file.tree.root_node();
8 let source = file.source.as_bytes();
9 let full_text = &file.source;
10
11 if !full_text.contains("Column(")
13 && !full_text.contains("mapped_column(")
14 && !full_text.contains("models.Model")
15 && !full_text.contains("BaseModel")
16 && !full_text.contains("SQLModel")
17 && !full_text.contains("Base)")
18 && !full_text.contains("DeclarativeBase")
19 {
20 return None;
21 }
22
23 let mut entities = Vec::new();
24
25 let mut known_model_classes: Vec<String> = Vec::new();
27
28 for i in 0..root.named_child_count() {
30 let node = root.named_child(i).unwrap();
31 let class_node = if node.kind() == "class_definition" {
32 node
33 } else if node.kind() == "decorated_definition" {
34 match get_inner_definition(&node) {
35 Some(n) if n.kind() == "class_definition" => n,
36 _ => continue,
37 }
38 } else {
39 continue;
40 };
41
42 let name = get_def_name(&class_node, source);
43 let bases = get_class_bases(&class_node, source);
44
45 let is_direct_model = bases.iter().any(|b| {
47 b == "Base"
48 || b.contains("DeclarativeBase")
49 || b == "models.Model"
50 || b.starts_with("models.")
51 || b == "BaseModel"
52 || b == "SQLModel"
53 });
54
55 if is_direct_model || has_table_kwarg(&class_node, source) {
56 known_model_classes.push(name);
57 }
58 }
59
60 let mut changed = true;
63 while changed {
64 changed = false;
65 for i in 0..root.named_child_count() {
66 let node = root.named_child(i).unwrap();
67 let class_node = if node.kind() == "class_definition" {
68 node
69 } else if node.kind() == "decorated_definition" {
70 match get_inner_definition(&node) {
71 Some(n) if n.kind() == "class_definition" => n,
72 _ => continue,
73 }
74 } else {
75 continue;
76 };
77
78 let name = get_def_name(&class_node, source);
79 if known_model_classes.contains(&name) {
80 continue;
81 }
82 let bases = get_class_bases(&class_node, source);
83 if bases
84 .iter()
85 .any(|b| known_model_classes.iter().any(|k| k == b))
86 {
87 known_model_classes.push(name);
88 changed = true;
89 }
90 }
91 }
92
93 for i in 0..root.named_child_count() {
94 let node = root.named_child(i).unwrap();
95
96 let class_node = if node.kind() == "class_definition" {
97 node
98 } else if node.kind() == "decorated_definition" {
99 match get_inner_definition(&node) {
100 Some(n) if n.kind() == "class_definition" => n,
101 _ => continue,
102 }
103 } else {
104 continue;
105 };
106
107 let class_name = get_def_name(&class_node, source);
108
109 if is_base_class_definition(&class_name) {
111 continue;
112 }
113
114 let bases = get_class_bases(&class_node, source);
115
116 let has_table_true = has_table_kwarg(&class_node, source);
118
119 let class_bases: Vec<String> = bases.iter().filter(|b| !b.contains('=')).cloned().collect();
121
122 if bases
124 .iter()
125 .any(|b| b == "Base" || b.contains("DeclarativeBase"))
126 && let Some(mut entity) = extract_sqlalchemy_entity(&class_node, source, &class_name)
127 {
128 entity.bases = class_bases;
129 entities.push(entity);
130 }
131 else if has_table_true
133 && full_text.contains("SQLModel")
134 && let Some(mut entity) = extract_sqlmodel_entity(&class_node, source, &class_name)
135 {
136 entity.bases = class_bases;
137 entities.push(entity);
138 }
139 else if bases
141 .iter()
142 .any(|b| b == "models.Model" || b.starts_with("models."))
143 && let Some(mut entity) = extract_django_model(&class_node, source, &class_name)
144 {
145 entity.bases = class_bases;
146 entities.push(entity);
147 }
148 else if bases.iter().any(|b| {
150 b == "BaseModel" || b == "SQLModel" || known_model_classes.iter().any(|k| k == b)
151 }) && let Some(mut entity) =
152 extract_pydantic_model(&class_node, source, &class_name)
153 {
154 entity.bases = class_bases;
155 entities.push(entity);
156 }
157 }
158
159 if entities.is_empty() {
160 return None;
161 }
162
163 let file_stem = file
164 .path
165 .rsplit('/')
166 .next()
167 .unwrap_or(&file.path)
168 .trim_end_matches(".py");
169 let capability_name = format!("{}-model", to_kebab_case(file_stem));
170
171 let mut capability = Capability::new(capability_name, file.path.clone());
172 capability.entities = entities;
173 Some(capability)
174}
175
176fn get_class_bases(class_node: &tree_sitter::Node, source: &[u8]) -> Vec<String> {
177 let mut bases = Vec::new();
178 if let Some(arg_list) = class_node.child_by_field_name("superclasses") {
179 for i in 0..arg_list.named_child_count() {
180 let arg = arg_list.named_child(i).unwrap();
181 bases.push(node_text(&arg, source));
182 }
183 }
184 bases
185}
186
187fn has_table_kwarg(class_node: &tree_sitter::Node, source: &[u8]) -> bool {
189 if let Some(arg_list) = class_node.child_by_field_name("superclasses") {
190 for i in 0..arg_list.named_child_count() {
191 let arg = arg_list.named_child(i).unwrap();
192 if arg.kind() == "keyword_argument" {
193 let text = node_text(&arg, source);
194 if text.contains("table") && text.contains("True") {
195 return true;
196 }
197 }
198 }
199 }
200 false
201}
202
203fn extract_sqlmodel_entity(
205 class_node: &tree_sitter::Node,
206 source: &[u8],
207 class_name: &str,
208) -> Option<Entity> {
209 let body = class_node.child_by_field_name("body")?;
210 let mut fields = Vec::new();
211 let table_name = class_name.to_lowercase();
212
213 for i in 0..body.named_child_count() {
214 let stmt = body.named_child(i).unwrap();
215 if stmt.kind() != "expression_statement" {
216 continue;
217 }
218
219 let text = node_text(&stmt, source);
220
221 if !text.contains(':') || text.starts_with('#') {
223 continue;
224 }
225
226 let parts: Vec<&str> = text.splitn(2, ':').collect();
228 if parts.len() != 2 {
229 continue;
230 }
231
232 let name = parts[0].trim().to_string();
233 if name.starts_with('_') || name == "model_config" {
234 continue;
235 }
236
237 let type_and_default = parts[1].trim();
238 let field_type = type_and_default
239 .split('=')
240 .next()
241 .unwrap_or("")
242 .trim()
243 .to_string();
244
245 if text.contains("Relationship(") {
247 fields.push(EntityField {
248 name,
249 field_type: format!("relationship({})", field_type),
250 constraints: vec!["relationship".to_string()],
251 });
252 continue;
253 }
254
255 let mut constraints = Vec::new();
256 if text.contains("Field(") {
257 if text.contains("primary_key=True") {
258 constraints.push("primary_key".to_string());
259 }
260 if text.contains("unique=True") {
261 constraints.push("unique".to_string());
262 }
263 if text.contains("index=True") {
264 constraints.push("indexed".to_string());
265 }
266 if text.contains("foreign_key=") {
267 constraints.push("relationship".to_string());
268 }
269 if text.contains("nullable=False") {
270 constraints.push("required".to_string());
271 }
272 if let Some(v) = extract_kwarg_value(&text, "max_length") {
273 constraints.push(format!("max_length={}", v));
274 }
275 if let Some(v) = extract_kwarg_value(&text, "min_length") {
276 constraints.push(format!("min_length={}", v));
277 }
278 }
279
280 if !text.contains("| None")
282 && !text.contains("Optional")
283 && !text.contains("= None")
284 && !constraints.contains(&"required".to_string())
285 {
286 constraints.push("required".to_string());
287 }
288
289 fields.push(EntityField {
290 name,
291 field_type,
292 constraints,
293 });
294 }
295
296 Some(Entity {
297 name: class_name.to_string(),
298 table: table_name,
299 fields,
300 bases: Vec::new(),
301 })
302}
303
304fn extract_sqlalchemy_entity(
306 class_node: &tree_sitter::Node,
307 source: &[u8],
308 class_name: &str,
309) -> Option<Entity> {
310 let body = class_node.child_by_field_name("body")?;
311 let mut fields = Vec::new();
312 let mut table_name = class_name.to_lowercase();
313
314 for i in 0..body.named_child_count() {
315 let stmt = body.named_child(i).unwrap();
316 if stmt.kind() != "expression_statement" {
317 continue;
318 }
319
320 let text = node_text(&stmt, source);
321
322 if text.contains("__tablename__") {
324 if let Some(val) = extract_assignment_string(&text) {
325 table_name = val;
326 }
327 continue;
328 }
329
330 if (text.contains("Column(")
332 || text.contains("relationship(")
333 || text.contains("mapped_column("))
334 && let Some(field) = parse_sqlalchemy_field(&text)
335 {
336 fields.push(field);
337 }
338 }
339
340 Some(Entity {
341 name: class_name.to_string(),
342 table: table_name,
343 fields,
344 bases: Vec::new(),
345 })
346}
347
348fn parse_sqlalchemy_field(text: &str) -> Option<EntityField> {
349 let (name, rhs) = if text.contains("mapped_column(") {
351 let colon_parts: Vec<&str> = text.splitn(2, ':').collect();
353 if colon_parts.len() != 2 {
354 return None;
355 }
356 let field_name = colon_parts[0].trim().to_string();
357 let after_colon = colon_parts[1].trim();
358 if let Some(eq_pos) = after_colon.find("= mapped_column(") {
360 (field_name, after_colon[eq_pos + 2..].trim().to_string())
361 } else if let Some(eq_pos) = after_colon.find("=mapped_column(") {
362 (field_name, after_colon[eq_pos + 1..].trim().to_string())
363 } else {
364 return None;
365 }
366 } else {
367 let parts: Vec<&str> = text.splitn(2, '=').collect();
368 if parts.len() != 2 {
369 return None;
370 }
371 (parts[0].trim().to_string(), parts[1].trim().to_string())
372 };
373
374 if rhs.starts_with("Column(") || rhs.starts_with("mapped_column(") {
375 let prefix_len = if rhs.starts_with("mapped_column(") {
376 14
377 } else {
378 7
379 };
380 let inner = &rhs[prefix_len..rhs.rfind(')')?];
381 let args: Vec<&str> = inner.split(',').map(|s| s.trim()).collect();
382 let field_type = args.first().unwrap_or(&"").to_string();
383
384 let mut constraints = Vec::new();
385 for arg in &args[1..] {
386 if arg.contains("primary_key=True") {
387 constraints.push("primary_key".to_string());
388 }
389 if arg.contains("nullable=False") {
390 constraints.push("required".to_string());
391 }
392 if arg.contains("unique=True") {
393 constraints.push("unique".to_string());
394 }
395 if arg.contains("index=True") {
396 constraints.push("indexed".to_string());
397 }
398 }
399
400 Some(EntityField {
401 name,
402 field_type,
403 constraints,
404 })
405 } else if rhs.starts_with("relationship(") {
406 let inner = &rhs[13..rhs.rfind(')')?];
407 let target = inner
408 .split(',')
409 .next()?
410 .trim()
411 .trim_matches('"')
412 .trim_matches('\'');
413 Some(EntityField {
414 name,
415 field_type: format!("relationship({})", target),
416 constraints: vec!["relationship".to_string()],
417 })
418 } else {
419 None
420 }
421}
422
/// Returns `true` for class names that denote a declarative base rather than an entity:
/// exactly `Base` or `DeclarativeBase` (case-sensitive).
fn is_base_class_definition(class_name: &str) -> bool {
    matches!(class_name, "Base" | "DeclarativeBase")
}
427
428fn extract_django_model(
430 class_node: &tree_sitter::Node,
431 source: &[u8],
432 class_name: &str,
433) -> Option<Entity> {
434 let body = class_node.child_by_field_name("body")?;
435 let mut fields = Vec::new();
436 let table_name = class_name.to_lowercase();
437
438 for i in 0..body.named_child_count() {
439 let stmt = body.named_child(i).unwrap();
440 if stmt.kind() != "expression_statement" {
441 continue;
442 }
443
444 let text = node_text(&stmt, source);
445
446 if (text.contains("Field(")
448 || text.contains("ForeignKey(")
449 || text.contains("ManyToManyField(")
450 || text.contains("OneToOneField("))
451 && let Some(field) = parse_django_field(&text)
452 {
453 fields.push(field);
454 }
455 }
456
457 Some(Entity {
458 name: class_name.to_string(),
459 table: table_name,
460 fields,
461 bases: Vec::new(),
462 })
463}
464
465fn parse_django_field(text: &str) -> Option<EntityField> {
466 let parts: Vec<&str> = text.splitn(2, '=').collect();
467 if parts.len() != 2 {
468 return None;
469 }
470
471 let name = parts[0].trim().to_string();
472 let rhs = parts[1].trim();
473
474 let field_type = rhs.split('(').next()?.trim().replace("models.", "");
476
477 let mut constraints = Vec::new();
478 if rhs.contains("primary_key=True") {
479 constraints.push("primary_key".to_string());
480 }
481 if rhs.contains("blank=False")
482 || rhs.contains("null=False")
483 || !rhs.contains("blank=True") && !rhs.contains("null=True")
484 {
485 constraints.push("required".to_string());
486 }
487 if rhs.contains("unique=True") {
488 constraints.push("unique".to_string());
489 }
490 if rhs.contains("max_length=")
491 && let Some(ml) = extract_kwarg_value(rhs, "max_length")
492 {
493 constraints.push(format!("max_length={}", ml));
494 }
495 if field_type.contains("ForeignKey") || field_type.contains("OneToOne") {
496 constraints.push("relationship".to_string());
497 }
498 if field_type.contains("ManyToMany") {
499 constraints.push("many_to_many".to_string());
500 }
501
502 Some(EntityField {
503 name,
504 field_type,
505 constraints,
506 })
507}
508
509fn extract_pydantic_model(
511 class_node: &tree_sitter::Node,
512 source: &[u8],
513 class_name: &str,
514) -> Option<Entity> {
515 let body = class_node.child_by_field_name("body")?;
516 let mut fields = Vec::new();
517
518 for i in 0..body.named_child_count() {
519 let stmt = body.named_child(i).unwrap();
520
521 let text = node_text(&stmt, source);
522
523 if stmt.kind() == "expression_statement" && text.contains(':') {
525 let parts: Vec<&str> = text.splitn(2, ':').collect();
526 if parts.len() == 2 {
527 let name = parts[0].trim().to_string();
528 let type_and_default = parts[1].trim();
529 let field_type = type_and_default
530 .split('=')
531 .next()
532 .unwrap_or("")
533 .trim()
534 .to_string();
535
536 if name.starts_with('_') || name == "model_config" || name == "Config" {
537 continue;
538 }
539
540 let mut constraints = Vec::new();
541 if text.contains("Field(") {
542 if let Some(v) = extract_kwarg_value(&text, "min_length") {
543 constraints.push(format!("min_length={}", v));
544 }
545 if let Some(v) = extract_kwarg_value(&text, "max_length") {
546 constraints.push(format!("max_length={}", v));
547 }
548 if text.contains("gt=") || text.contains("ge=") {
549 constraints.push("min_value".to_string());
550 }
551 if text.contains("lt=") || text.contains("le=") {
552 constraints.push("max_value".to_string());
553 }
554 }
555
556 if !field_type.starts_with("Optional") && !text.contains("= None") {
557 constraints.push("required".to_string());
558 }
559
560 fields.push(EntityField {
561 name,
562 field_type,
563 constraints,
564 });
565 }
566 }
567 }
568
569 Some(Entity {
570 name: class_name.to_string(),
571 table: class_name.to_lowercase(),
572 fields,
573 bases: Vec::new(),
574 })
575}
576
577fn extract_assignment_string(text: &str) -> Option<String> {
578 let after_eq = text.split('=').nth(1)?.trim();
579 Some(clean_string_literal(after_eq))
580}
581
/// Returns the trimmed value of keyword argument `key` in a call's source text.
///
/// The value runs from just after `key=` up to the next `,` or `)`. Only standalone
/// occurrences of `key=` count: an occurrence preceded by an identifier character
/// (as in `default_max_length=` when looking for `max_length`) is ignored.
/// Returns `None` when the keyword is absent or its value has no terminating `,`/`)`.
fn extract_kwarg_value<'a>(text: &'a str, key: &str) -> Option<&'a str> {
    let pattern = format!("{key}=");
    let start = text
        .match_indices(pattern.as_str())
        .map(|(idx, _)| idx)
        .find(|&idx| {
            !text[..idx]
                .chars()
                .next_back()
                .is_some_and(|c| c.is_alphanumeric() || c == '_')
        })?
        + pattern.len();
    let remaining = &text[start..];
    let end = remaining.find([',', ')'])?;
    Some(remaining[..end].trim())
}
589
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Parses `source` as a Python file located at `path`.
    fn parse_source(source: &str, path: &str) -> ParsedFile {
        ParsedFile::parse(source.to_string(), path.to_string()).unwrap()
    }

    /// `true` when `constraints` contains exactly `wanted`.
    fn has_constraint(constraints: &[String], wanted: &str) -> bool {
        constraints.iter().any(|c| c == wanted)
    }

    #[test]
    fn test_sqlalchemy_model() {
        let source = r#"
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    email = Column(String, unique=True)
    posts = relationship("Post")
"#;

        let capability = extract(&parse_source(source, "models/user.py")).unwrap();
        let user = &capability.entities[0];

        assert_eq!(user.name, "User");
        assert_eq!(user.table, "users");
        assert_eq!(user.fields.len(), 4);
        assert!(has_constraint(&user.fields[0].constraints, "primary_key"));
        assert!(has_constraint(&user.fields[2].constraints, "unique"));
    }

    #[test]
    fn test_django_model() {
        let source = r#"
from django.db import models

class Article(models.Model):
    title = models.CharField(max_length=200)
    content = models.TextField()
    author = models.ForeignKey("User", on_delete=models.CASCADE)
    tags = models.ManyToManyField("Tag")
"#;

        let capability = extract(&parse_source(source, "models.py")).unwrap();
        let article = &capability.entities[0];

        assert_eq!(article.name, "Article");
        assert_eq!(article.fields.len(), 4);
        let title_constraints = &article.fields[0].constraints;
        assert!(title_constraints.iter().any(|c| c.contains("max_length")));
        assert!(has_constraint(&article.fields[2].constraints, "relationship"));
    }

    #[test]
    fn test_pydantic_model() {
        let source = r#"
from pydantic import BaseModel, Field

class UserCreate(BaseModel):
    name: str = Field(min_length=2, max_length=50)
    email: str
    age: int = Field(gt=0, lt=150)
    bio: Optional[str] = None
"#;

        let capability = extract(&parse_source(source, "schemas/user.py")).unwrap();
        let schema = &capability.entities[0];

        assert_eq!(schema.name, "UserCreate");
        assert_eq!(schema.fields.len(), 4);
        let name_constraints = &schema.fields[0].constraints;
        assert!(name_constraints.iter().any(|c| c.contains("min_length")));
        // `bio` is optional, so it must not be flagged as required.
        assert!(!has_constraint(&schema.fields[3].constraints, "required"));
    }

    #[test]
    fn test_no_model() {
        let source = r#"
class Helper:
    def do_something(self):
        pass
"#;
        let parsed = parse_source(source, "utils.py");
        assert!(extract(&parsed).is_none());
    }

    #[test]
    fn test_sqlmodel_table_entity() {
        let source = r#"
from sqlmodel import Field, SQLModel, Relationship

class UserBase(SQLModel):
    email: str = Field(unique=True, max_length=255)
    is_active: bool = True

class User(UserBase, table=True):
    id: int = Field(primary_key=True)
    hashed_password: str
    items: list["Item"] = Relationship(back_populates="owner")
"#;

        let capability = extract(&parse_source(source, "models.py")).unwrap();
        // Both the schema base and the table model are reported.
        assert!(capability.entities.len() >= 2);

        let user = capability
            .entities
            .iter()
            .find(|entity| entity.name == "User")
            .unwrap();
        assert_eq!(user.table, "user");

        let fields = &user.fields;
        assert!(fields
            .iter()
            .any(|f| f.name == "id" && has_constraint(&f.constraints, "primary_key")));
        assert!(fields.iter().any(|f| f.name == "hashed_password"));
        assert!(fields
            .iter()
            .any(|f| has_constraint(&f.constraints, "relationship")));
    }

    #[test]
    fn test_sqlalchemy_mapped_column() {
        let source = r#"
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import Integer, String, Float

class Trade(Base):
    __tablename__ = "trades"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    ticker: Mapped[str] = mapped_column(String(20))
    price: Mapped[float] = mapped_column(Float)
    status: Mapped[str] = mapped_column(String(20), unique=True)
"#;

        let capability = extract(&parse_source(source, "models.py")).unwrap();
        let trade = &capability.entities[0];

        assert_eq!(trade.name, "Trade");
        assert_eq!(trade.table, "trades");
        assert_eq!(trade.fields.len(), 4);
        assert!(has_constraint(&trade.fields[0].constraints, "primary_key"));

        let ticker = &trade.fields[1];
        assert_eq!(ticker.name, "ticker");
        assert_eq!(ticker.field_type, "String(20)");
        assert!(has_constraint(&trade.fields[3].constraints, "unique"));
    }

    #[test]
    fn test_sqlmodel_inheritance_chain() {
        let source = r#"
from sqlmodel import Field, SQLModel

class UserBase(SQLModel):
    email: str = Field(max_length=255)

class UserCreate(UserBase):
    password: str = Field(min_length=8)

class ItemBase(SQLModel):
    title: str

class ItemCreate(ItemBase):
    pass
"#;

        let capability = extract(&parse_source(source, "models.py")).unwrap();

        let names: Vec<&str> = capability
            .entities
            .iter()
            .map(|entity| entity.name.as_str())
            .collect();
        let expectations = [
            ("UserBase", ""),
            ("UserCreate", " (inherits UserBase)"),
            ("ItemBase", ""),
            ("ItemCreate", " (inherits ItemBase)"),
        ];
        for (expected, why) in expectations {
            assert!(
                names.contains(&expected),
                "Should find {}{}, got: {:?}",
                expected,
                why,
                names
            );
        }

        let find = |wanted: &str| {
            capability
                .entities
                .iter()
                .find(|entity| entity.name == wanted)
                .unwrap()
        };

        let user_create = find("UserCreate");
        assert_eq!(user_create.fields.len(), 1);
        assert_eq!(user_create.fields[0].name, "password");

        let item_create = find("ItemCreate");
        assert_eq!(item_create.fields.len(), 0);
    }

    #[test]
    fn test_base_class_not_entity() {
        let source = r#"
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    pass
"#;
        let parsed = parse_source(source, "database.py");
        assert!(extract(&parsed).is_none());
    }
}