import datetime as dt
import decimal
import os
import sys
import tempfile
import uuid
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, cast
from unittest.mock import MagicMock, PropertyMock, mock_open, patch

import pytest
from sqlalchemy import (
    JSON,
    UUID,
    Boolean,
    Column,
    Date,
    DateTime,
)
from sqlalchemy import Enum as SQLAlchemyEnum
from sqlalchemy import (
    Float,
    Integer,
    Interval,
    Numeric,
    String,
    Time,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column

import nomy_data_models
from nomy_data_models.models.base import BaseModel
from nomy_data_models.py_to_rust import (
    _generate_rust_fields,
    _print_unknown_type_warning,
    generate_rust_enum,
    generate_rust_model,
    generate_rust_models,
    get_all_models,
    get_required_imports,
    sqlalchemy_to_rust_type,
)
class StatusEnum(str, Enum):
    """String-valued status enum used as input to the Rust enum generator."""

    # Member order is significant: it is the order the variants are rendered in.
    ACTIVE = "active"
    INACTIVE = "inactive"
    PENDING = "pending"
class DirectionEnum(str, Enum):
    """Second string-valued enum, used to check the generator is not StatusEnum-specific."""

    # Declaration order matters for the rendered variant order.
    UP = "up"
    DOWN = "down"
    SIDEWAYS = "sideways"
class RustModelExample(BaseModel):
    """Concrete fixture model with one column per SQLAlchemy type mapped to Rust."""

    __abstract__ = False

    # ``Mapped[...]`` describes the Python value type seen on instances; the
    # explicit column type passed to ``mapped_column`` is what determines the
    # SQL type (and therefore the generated Rust type). Nullable columns are
    # annotated ``Optional[...]`` so static type checkers see ``None``.
    name: Mapped[str] = mapped_column(String(50), nullable=False)
    age: Mapped[int] = mapped_column(Integer, nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    created_at: Mapped[dt.datetime] = mapped_column(DateTime, nullable=False)
    model_id: Mapped[uuid.UUID] = mapped_column(UUID, nullable=False)
    data: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)
    amount: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    # SQLAlchemy's Numeric returns decimal.Decimal by default (asdecimal=True).
    decimal_value: Mapped[Optional[decimal.Decimal]] = mapped_column(
        Numeric, nullable=True
    )
    date_value: Mapped[Optional[dt.date]] = mapped_column(Date, nullable=True)
    time_value: Mapped[Optional[dt.time]] = mapped_column(Time, nullable=True)
    interval_value: Mapped[Optional[dt.timedelta]] = mapped_column(
        Interval, nullable=True
    )
class AbstractModel(BaseModel):
    """Intermediate base flagged abstract; the generator is expected to skip it."""

    # Marks this class as not mapped to its own table.
    __abstract__ = True
class ConcreteAbstractModel(AbstractModel):
    """Concrete model that inherits through an abstract intermediate base."""

    # Re-enable mapping that the abstract parent turned off.
    __abstract__ = False

    name: Mapped[str] = mapped_column(String(50), nullable=False)
class ModelWithTableArgs(BaseModel):
    """Concrete model that also declares dialect-specific ``__table_args__``."""

    __abstract__ = False
    __table_args__ = dict(sqlite_autoincrement=True)

    name: Mapped[str] = mapped_column(String(50), nullable=False)
@pytest.fixture()
def test_status_enum():
    """Provide the ``StatusEnum`` class (not an instance) to tests."""
    enum_cls = StatusEnum
    return enum_cls
@pytest.fixture()
def test_direction_enum():
    """Provide the ``DirectionEnum`` class (not an instance) to tests."""
    enum_cls = DirectionEnum
    return enum_cls
@pytest.fixture()
def test_rust_model():
    """Provide the ``RustModelExample`` model class to tests."""
    model_cls = RustModelExample
    return model_cls
class TestRustEnumGeneration:
    """Tests for ``generate_rust_enum`` (Python ``Enum`` -> Rust ``enum`` source)."""

    @staticmethod
    def _assert_str_enum_rendered(rust_enum, enum_cls):
        """Assert that ``rust_enum`` fully renders the string-valued ``enum_cls``.

        Checks the derive attribute, the ``enum`` and ``impl`` headers, and the
        ``as_str`` method. For every member it also checks the serde rename
        attribute, the variant line and the ``as_str`` match arm.

        Args:
            rust_enum: Rust source returned by ``generate_rust_enum``.
            enum_cls: The string-valued Python ``Enum`` it was generated from.
        """
        import re

        enum_name = enum_cls.__name__
        assert (
            "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]"
            in rust_enum
        ), "Rust enum should have the correct derive attributes"
        assert (
            f"pub enum {enum_name} {{" in rust_enum
        ), "Rust enum should have the correct name"
        assert (
            f"impl {enum_name} {{" in rust_enum
        ), "Rust enum should have an impl block"
        assert (
            "pub fn as_str(&self) -> &'static str {" in rust_enum
        ), "as_str method should be present"
        assert "match self {" in rust_enum, "match statement should be present"

        for member in enum_cls:
            assert (
                f'#[serde(rename = "{member.value}")]' in rust_enum
            ), f"{member.name} variant should have correct serde rename"
            # A bare substring test is not enough: "INACTIVE," contains
            # "ACTIVE,". Require that the variant name is not preceded by
            # another identifier character.
            assert re.search(
                rf"(?<!\w){re.escape(member.name)},", rust_enum
            ), f"{member.name} variant should be present"
            assert (
                f'{enum_name}::{member.name} => "{member.value}",' in rust_enum
            ), f"{member.name} match arm should be present"

    def test_generate_rust_enum(self, test_status_enum):
        """StatusEnum renders every variant, serde rename and match arm."""
        rust_enum = generate_rust_enum(test_status_enum)
        self._assert_str_enum_rendered(rust_enum, test_status_enum)

    def test_generate_rust_enum_with_different_enum(self, test_direction_enum):
        """DirectionEnum renders every variant, serde rename and match arm."""
        rust_enum = generate_rust_enum(test_direction_enum)
        self._assert_str_enum_rendered(rust_enum, test_direction_enum)

    def test_generate_rust_enum_with_numeric_values(self):
        """Integer-valued enums render explicit discriminants."""

        class NumericEnum(Enum):
            ONE = 1
            TWO = 2
            THREE = 3

        rust_enum = generate_rust_enum(NumericEnum)
        assert "pub enum NumericEnum {" in rust_enum
        assert "ONE = 1," in rust_enum
        assert "TWO = 2," in rust_enum
        assert "THREE = 3," in rust_enum
        # Case-insensitive check of the as_str match arm for numeric enums.
        assert 'numericenum::one => "one",' in rust_enum.lower()
class TestRustTypeConversion:
    """Table-driven checks for ``sqlalchemy_to_rust_type``."""

    # (SQLAlchemy type name, expected Rust type). "uuid" is a lower-case
    # spelling and "UnknownType" is an unrecognised name that maps to "String".
    _TYPE_CASES = [
        ("Integer", "i32"),
        ("BigInteger", "i64"),
        ("SmallInteger", "i16"),
        ("String", "String"),
        ("Text", "String"),
        ("Boolean", "bool"),
        ("Float", "f64"),
        ("Numeric", "Decimal"),
        ("DateTime", "DateTime<Utc>"),
        ("Date", "NaiveDate"),
        ("Time", "NaiveTime"),
        ("UUID", "Uuid"),
        ("JSON", "JsonValue"),
        ("JSONB", "JsonValue"),
        ("Interval", "chrono::Duration"),
        ("Enum", "String"),
        ("uuid", "Uuid"),
        ("UnknownType", "String"),
    ]

    @pytest.mark.parametrize(("sqlalchemy_type", "expected_rust_type"), _TYPE_CASES)
    def test_sqlalchemy_to_rust_type(self, sqlalchemy_type, expected_rust_type):
        """Each SQLAlchemy type name maps to the expected Rust type string."""
        result = sqlalchemy_to_rust_type(sqlalchemy_type)
        matches = result == expected_rust_type
        assert (
            matches
        ), f"Expected {sqlalchemy_type} to convert to {expected_rust_type}, got {result}"
class TestSQLAlchemyModelConversion:
    """Tests for SQLAlchemy-model -> Rust-struct conversion and file output."""

    def test_get_required_imports(self, test_rust_model):
        """RustModelExample's column types pull in the matching Rust imports."""
        # Call the real function directly. Patching
        # ``nomy_data_models.py_to_rust.get_required_imports`` would not affect
        # the name already imported into this module, so no patch is used.
        imports = get_required_imports(test_rust_model)
        assert isinstance(imports, set), "Result should be a set"
        assert any(
            "use chrono" in imp for imp in imports
        ), "Should include chrono import"
        assert any(
            "use uuid::Uuid" in imp for imp in imports
        ), "Should include UUID import"
        assert any(
            "use serde_json::Value" in imp for imp in imports
        ), "Should include JSON import"
        assert any(
            "use rust_decimal::Decimal" in imp for imp in imports
        ), "Should include Decimal import"

    def test_get_required_imports_with_no_table(self):
        """A plain class without ``__table__`` needs no imports."""

        class ModelWithoutTable:
            pass

        imports = get_required_imports(ModelWithoutTable)
        assert isinstance(imports, set), "Result should be a set"
        assert (
            len(imports) == 0
        ), "Should return empty set for model without __table__ attribute"

    def test_get_required_imports_with_enum_column(self):
        """An Enum column imports the enum type from ``crate::models``."""

        class TestEnum(str, Enum):
            A = "a"
            B = "b"

        class ModelWithEnum(BaseModel):
            __abstract__ = False
            id = mapped_column(Integer, primary_key=True)
            enum_col = mapped_column(SQLAlchemyEnum(TestEnum), nullable=False)

        imports = get_required_imports(ModelWithEnum)
        assert "use crate::models::TestEnum;" in imports
        assert "use chrono::{DateTime, Utc};" in imports
        assert "use uuid::Uuid;" not in imports
        assert "use rust_decimal::Decimal;" not in imports

    def test_get_required_imports_with_enum_column_no_enum_class(self):
        """An Enum-typed column whose type has no ``enum_class`` adds no crate import."""

        class ModelWithEnumNoEnumClass(BaseModel):
            __abstract__ = False
            __table_args__ = {"extend_existing": True}
            id = mapped_column(Integer, primary_key=True)
            enum_col = mapped_column(Integer, nullable=False)
            created_at = mapped_column(DateTime, nullable=False)

        id_column_mock = MagicMock()
        id_column_mock.name = "id"
        id_column_mock.type.__class__.__name__ = "Integer"

        datetime_column_mock = MagicMock()
        datetime_column_mock.name = "created_at"
        datetime_column_mock.type.__class__.__name__ = "DateTime"

        enum_column_mock = MagicMock()
        enum_column_mock.name = "enum_col"
        enum_column_mock.type.__class__.__name__ = "Enum"
        # Accessing ``enum_class`` on this column type raises AttributeError.
        type(enum_column_mock.type).enum_class = PropertyMock(
            side_effect=AttributeError("No enum_class")
        )

        with patch.object(
            ModelWithEnumNoEnumClass.__table__,
            "columns",
            [id_column_mock, datetime_column_mock, enum_column_mock],
        ):
            imports = get_required_imports(ModelWithEnumNoEnumClass)

        assert not any("use crate::models::" in imp for imp in imports)
        assert "use chrono::{DateTime, Utc};" in imports

    def test_generate_rust_model(self, test_rust_model):
        """Every own and inherited column appears as a typed struct field."""
        result = generate_rust_model(test_rust_model)
        expected_fields = {
            "id": "Uuid",
            "name": "String",
            "age": "i32",
            "is_active": "bool",
            "created_at": "DateTime<Utc>",
            "model_id": "Uuid",
            "data": "JsonValue",
            "amount": "f64",
            "decimal_value": "Decimal",
            "date_value": "NaiveDate",
            "time_value": "NaiveTime",
            "interval_value": "chrono::Duration",
            "updated_at": "DateTime<Utc>",
            "created_by": "String",
            "updated_by": "String",
        }

        assert "pub struct RustModelExample {" in result
        for field_name, rust_type in expected_fields.items():
            assert (
                f"pub {field_name}: {rust_type}," in result
            ), f"Struct should declare {field_name}: {rust_type}"

        assert "impl RustModelExample {" in result
        assert "pub fn new(" in result
        # NOTE(review): every ``pub <field>: <type>,`` struct line already
        # contains ``<field>: <type>,``, so these constructor-argument checks
        # cannot fail on their own. Tighten once the constructor layout is pinned.
        for field_name, rust_type in expected_fields.items():
            assert f"{field_name}: {rust_type}," in result
        assert ") -> Self {" in result
        assert "Self {" in result

    def test_generate_rust_model_abstract_class(self):
        """Abstract models produce no Rust output."""
        result = generate_rust_model(AbstractModel)
        assert result == "", "Should return empty string for abstract classes"

    def test_generate_rust_model_without_table(self):
        """Plain classes without ``__table__`` or annotations produce no output."""

        class ModelWithoutTable:
            __name__ = "ModelWithoutTable"

        result = generate_rust_model(ModelWithoutTable)
        assert (
            result == ""
        ), "Should return empty string for models without __table__ attribute"

    def test_generate_rust_model_with_imports(self, test_rust_model):
        """Imports returned by ``get_required_imports`` are emitted in the output."""
        template_content = """
// Model template
{imports}
/// {model_doc}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct {model_name} {
{fields}
}
impl {model_name} {
pub fn new({constructor_args}) -> Self {
Self {
{constructor_body}
}
}
}
"""
        imports = {
            "use chrono::{DateTime, Utc};",
            "use uuid::Uuid;",
            "use serde_json::Value as JsonValue;",
            "use rust_decimal::Decimal;",
            "use custom::Type;",
        }
        with patch("builtins.open", mock_open(read_data=template_content)):
            with patch("pathlib.Path.open", mock_open(read_data=template_content)):
                with patch(
                    "nomy_data_models.py_to_rust.get_required_imports",
                    return_value=imports,
                ):
                    result = generate_rust_model(test_rust_model)
        assert "use custom::Type;" in result, "Should include custom import"

    def test_generate_rust_models(self):
        """Generating into an explicit directory writes output files."""
        with tempfile.TemporaryDirectory() as temp_dir:
            with patch(
                "nomy_data_models.py_to_rust.get_all_models",
                return_value={"TestRustModel": RustModelExample},
            ):
                with patch(
                    "nomy_data_models.py_to_rust.generate_rust_model",
                    return_value="// Mock Rust model",
                ):
                    with patch(
                        "nomy_data_models.py_to_rust.generate_rust_enum",
                        return_value="// Mock Rust enum",
                    ):
                        with patch("os.makedirs"), patch(
                            "builtins.open", mock_open()
                        ) as mocked_open:
                            generate_rust_models(temp_dir)
        mocked_open.assert_called()

    def test_generate_rust_models_with_default_output_dir(self):
        """Generating with the default output directory writes output files."""
        with patch(
            "nomy_data_models.py_to_rust.get_all_models",
            return_value={"TestRustModel": RustModelExample},
        ):
            with patch(
                "nomy_data_models.py_to_rust.generate_rust_model",
                return_value="// Mock Rust model",
            ):
                with patch(
                    "nomy_data_models.py_to_rust.generate_rust_enum",
                    return_value="// Mock Rust enum",
                ):
                    # Patch both directory-creation APIs so the default output
                    # directory is never created for real during the test.
                    with patch("os.makedirs"), patch("pathlib.Path.mkdir"):
                        with patch("builtins.open", mock_open()) as mocked_open:
                            generate_rust_models()
        mocked_open.assert_called()

    def test_generate_rust_models_with_empty_model_output(self):
        """An empty model rendering must not make file generation fail."""
        with tempfile.TemporaryDirectory() as temp_dir:
            with patch(
                "nomy_data_models.py_to_rust.get_all_models",
                return_value={"AbstractModel": AbstractModel},
            ):
                with patch(
                    "nomy_data_models.py_to_rust.generate_rust_model", return_value=""
                ):
                    with patch(
                        "nomy_data_models.py_to_rust.generate_rust_enum",
                        return_value="// Mock Rust enum",
                    ):
                        with patch("os.makedirs"), patch("builtins.open", mock_open()):
                            # The test passes if this call completes without raising.
                            generate_rust_models(temp_dir)

    def test_get_all_models(self):
        """Check simulated model-selection rules against a set of mocks.

        NOTE(review): this runs ``simulate_get_all_models``, a local
        re-statement of the selection rules, over mocks. It does not call
        ``get_all_models`` itself, so it cannot catch regressions in that
        function; ``test_get_all_models_with_real_models`` exercises the real
        implementation.
        """
        mock_base_model = MagicMock()
        mock_base_model.__name__ = "BaseModel"

        mock_concrete_model = MagicMock()
        mock_concrete_model.__name__ = "ConcreteModel"
        mock_concrete_model.__module__ = "nomy_data_models.models.concrete_model"
        mock_concrete_model.__mro__ = (mock_concrete_model, mock_base_model, object)

        mock_abstract_model = MagicMock()
        mock_abstract_model.__name__ = "AbstractModel"
        mock_abstract_model.__abstract__ = True
        mock_abstract_model.__module__ = "nomy_data_models.models.abstract_model"
        mock_abstract_model.__mro__ = (mock_abstract_model, mock_base_model, object)

        mock_concrete_abstract_model = MagicMock()
        mock_concrete_abstract_model.__name__ = "ConcreteAbstractModel"
        mock_concrete_abstract_model.__abstract__ = False
        mock_concrete_abstract_model.__module__ = (
            "nomy_data_models.models.concrete_abstract_model"
        )
        mock_concrete_abstract_model.__mro__ = (
            mock_concrete_abstract_model,
            mock_abstract_model,
            mock_base_model,
            object,
        )

        mock_model_with_table_args = MagicMock()
        mock_model_with_table_args.__name__ = "ModelWithTableArgs"
        mock_model_with_table_args.__table_args__ = {"sqlite_autoincrement": True}
        mock_model_with_table_args.__module__ = (
            "nomy_data_models.models.model_with_table_args"
        )
        mock_model_with_table_args.__mro__ = (
            mock_model_with_table_args,
            mock_base_model,
            object,
        )

        mock_external_model = MagicMock()
        mock_external_model.__name__ = "ExternalModel"
        mock_external_model.__module__ = "external_package.models"
        mock_external_model.__mro__ = (mock_external_model, mock_base_model, object)

        # Candidate names in discovery order, mapped to their mock classes.
        candidates = {
            "BaseModel": mock_base_model,
            "ConcreteModel": mock_concrete_model,
            "AbstractModel": mock_abstract_model,
            "ConcreteAbstractModel": mock_concrete_abstract_model,
            "ModelWithTableArgs": mock_model_with_table_args,
            "ExternalModel": mock_external_model,
        }

        def simulate_get_all_models():
            """Keep non-abstract BaseModel subclasses from ``nomy_data_models.models``."""
            result = {}
            for name, item in candidates.items():
                if hasattr(item, "__mro__") and mock_base_model in item.__mro__:
                    is_abstract = getattr(item, "__abstract__", False)
                    if (
                        item != mock_base_model
                        and not is_abstract
                        and item.__module__.startswith("nomy_data_models.models")
                    ):
                        result[name] = item
            return result

        result = simulate_get_all_models()
        assert isinstance(result, dict), "Result should be a dictionary"
        assert "ConcreteModel" in result, "Should include concrete models"
        assert (
            "ConcreteAbstractModel" in result
        ), "Should include concrete models that inherit from abstract models"
        assert (
            "ModelWithTableArgs" in result
        ), "Should include models with __table_args__"
        assert (
            "AbstractModel" not in result
        ), "Should not include abstract models without table attributes"
        assert "BaseModel" not in result, "Should not include BaseModel"
        assert (
            "ExternalModel" not in result
        ), "Should not include models from external packages"

    def test_get_all_models_with_real_models(self):
        """Discover real models and check they all come from the models package."""
        result = get_all_models()
        assert isinstance(result, dict), "Result should be a dictionary"
        assert len(result) > 0, "Should find at least one model"
        assert "BaseModel" not in result, "Should not include BaseModel"
        for name, model in result.items():
            assert issubclass(
                model, BaseModel
            ), f"Model {name} should be a SQLAlchemy model"
            assert model.__module__.startswith(
                "nomy_data_models.models"
            ), f"Model {name} should be from nomy_data_models.models package"

    def test_generate_rust_model_with_unknown_type(self, capsys):
        """Unknown column types print a warning and fall back to ``String``."""
        mock_model = MagicMock()
        mock_model.__name__ = "MockModel"
        mock_model.__doc__ = "A mock model with an unknown column type."
        mock_model.__abstract__ = False

        mock_column = MagicMock()
        mock_column.name = "unknown_column"
        mock_column.type = MagicMock()
        type(mock_column.type).__name__ = "UnknownType"

        mock_table = MagicMock()
        mock_table.columns = [mock_column]
        mock_model.__table__ = mock_table

        # capsys captures and restores stdout safely. Replacing sys.stdout by hand
        # and then resetting it to sys.__stdout__ would break pytest's own capture.
        result = generate_rust_model(mock_model)
        warning_message = capsys.readouterr().out
        assert "Warning: Unknown SQLAlchemy type UnknownType" in warning_message
        assert "defaulting to String" in warning_message
        assert "pub unknown_column: String," in result

    def test_print_unknown_type_warning(self, capsys):
        """The warning helper prints the type name and the String fallback."""
        _print_unknown_type_warning("TestType")
        warning_message = capsys.readouterr().out
        assert "Warning: Unknown SQLAlchemy type TestType" in warning_message
        assert "defaulting to String" in warning_message

    def test_generate_rust_model_fields(self):
        """A single String column becomes a ``String`` struct field."""
        mock_model = MagicMock()
        mock_model.__name__ = "MockModel"
        mock_model.__doc__ = "A mock model for testing field generation."
        mock_model.__abstract__ = False

        mock_column = MagicMock()
        mock_column.name = "test_field"
        mock_column.type = MagicMock()
        type(mock_column.type).__name__ = "String"

        mock_table = MagicMock()
        mock_table.columns = [mock_column]
        mock_model.__table__ = mock_table

        result = generate_rust_model(mock_model)
        assert "pub test_field: String," in result

    def test_generate_rust_model_multiple_fields(self):
        """Each of several column types maps to its Rust field type."""
        mock_model = MagicMock()
        mock_model.__name__ = "MultiFieldModel"
        mock_model.__doc__ = "A mock model with multiple fields for testing."
        mock_model.__abstract__ = False

        columns = []
        for i, type_name in enumerate(
            ["String", "Integer", "Boolean", "DateTime", "UUID"]
        ):
            mock_column = MagicMock()
            mock_column.name = f"field_{i}"
            mock_column.type = MagicMock()
            type(mock_column.type).__name__ = type_name
            columns.append(mock_column)

        mock_table = MagicMock()
        mock_table.columns = columns
        mock_model.__table__ = mock_table

        result = generate_rust_model(mock_model)
        assert "pub field_0: String," in result
        assert "pub field_1: i32," in result
        assert "pub field_2: bool," in result
        assert "pub field_3: DateTime<Utc>," in result
        assert "pub field_4: Uuid," in result

    def test_generate_rust_fields_function(self):
        """``_generate_rust_fields`` renders one indented line per column."""
        columns = [
            ("id", "Uuid"),
            ("name", "String"),
            ("age", "i32"),
            ("is_active", "bool"),
            ("created_at", "DateTime<Utc>"),
        ]
        fields = _generate_rust_fields(columns)
        assert len(fields) == 5, "Should generate 5 field definitions"
        assert " pub id: Uuid," in fields
        assert " pub name: String," in fields
        assert " pub age: i32," in fields
        assert " pub is_active: bool," in fields
        assert " pub created_at: DateTime<Utc>," in fields

    def test_generate_rust_model_with_none_type(self, capsys):
        """A ``None`` type mapping warns and falls back to ``String``."""
        mock_model = MagicMock()
        mock_model.__name__ = "MockModelWithNoneType"
        mock_model.__doc__ = "A mock model for testing None type handling."
        mock_model.__abstract__ = False

        mock_column = MagicMock()
        mock_column.name = "none_type_field"
        mock_column.type = MagicMock()
        type(mock_column.type).__name__ = "NoneType"

        mock_table = MagicMock()
        mock_table.columns = [mock_column]
        mock_model.__table__ = mock_table

        with patch(
            "nomy_data_models.py_to_rust.sqlalchemy_to_rust_type",
            side_effect=lambda t: None if t == "NoneType" else "String",
        ):
            result = generate_rust_model(mock_model)

        warning_message = capsys.readouterr().out
        assert "Warning: Unknown SQLAlchemy type NoneType" in warning_message
        assert "defaulting to String" in warning_message
        assert "pub none_type_field: String," in result

    def test_generate_rust_model_with_multiline_docstring(self):
        """Each line of a multi-line model docstring is carried into the output."""

        class ModelWithMultilineDoc(BaseModel):
            """This is a model with a multi-line docstring.

            It has multiple lines.
            And even more lines.
            """

            __abstract__ = False
            id = mapped_column(Integer, primary_key=True)
            name = mapped_column(String(50), nullable=False)

        result = generate_rust_model(ModelWithMultilineDoc)
        assert "/// This is a model with a multi-line docstring." in result
        assert "It has multiple lines." in result
        assert "And even more lines." in result

    def test_generate_rust_model_with_annotations(self):
        """Plain annotated classes (no ``__table__``) map annotations to Rust types."""

        class ModelWithAnnotations:
            __name__ = "ModelWithAnnotations"
            id: int
            name: str
            is_active: bool
            created_at: "datetime"
            model_id: "UUID"
            amount: float
            decimal_value: "Decimal"
            data: dict

        result = generate_rust_model(ModelWithAnnotations)
        assert "pub struct ModelWithAnnotations {" in result
        assert "pub id: i32," in result
        assert "pub name: String," in result
        assert "pub is_active: bool," in result
        assert "pub created_at: DateTime<Utc>," in result
        assert "pub model_id: Uuid," in result
        assert "pub amount: f64," in result
        assert "pub decimal_value: Decimal," in result
        assert "pub data: JsonValue," in result

    def test_generate_rust_models_file_writing(self, tmp_path):
        """``generate_rust_models`` opens and writes output files."""
        with patch(
            "nomy_data_models.py_to_rust.get_all_models",
            return_value={"TestModel1": MagicMock(), "TestModel2": MagicMock()},
        ):
            with patch(
                "nomy_data_models.py_to_rust.generate_rust_model",
                return_value="// Mock Rust model",
            ):
                with patch(
                    "nomy_data_models.py_to_rust.generate_rust_enum",
                    return_value="// Mock Rust enum",
                ):
                    with patch("pathlib.Path.mkdir"):
                        with patch("builtins.open", mock_open()) as mock_file:
                            generate_rust_models(output_dir=str(tmp_path))
        mock_file.assert_called()
        mock_file().write.assert_called()

    def test_get_all_models_with_external_module(self):
        """Models whose ``__module__`` is outside the package are excluded."""

        class ExternalModel(BaseModel):
            __module__ = "external_package.models"
            __abstract__ = False

        def fake_getattr(obj, name, *default):
            # Forward an optional default so 3-argument ``getattr`` calls made
            # while the builtin is patched keep working.
            if name == "ExternalModel":
                return ExternalModel
            return getattr(obj, name, *default)

        with patch("nomy_data_models.py_to_rust.dir", return_value=["ExternalModel"]):
            with patch(
                "nomy_data_models.py_to_rust.getattr",
                side_effect=fake_getattr,
            ):
                result = get_all_models()

        assert isinstance(result, dict), "Result should be a dictionary"
        assert (
            "ExternalModel" not in result
        ), "Should not include models from external packages"
if __name__ == "__main__":
    # Propagate pytest's exit status so a direct ``python <this file>`` run
    # fails the calling shell/CI step when any test fails.
    sys.exit(pytest.main(["-v", __file__]))