nomy-data-models 0.35.6

Data model definitions for Nomy wallet analysis data processing
Documentation
"""Pytest configuration for the nomy_data_models package."""

import json
import os
import sys
import uuid
from datetime import datetime, timezone
from decimal import Decimal
from pathlib import Path

import pytest
from sqlalchemy import MetaData, String, Text, TypeDecorator, create_engine, inspect
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session, clear_mappers, sessionmaker

# Make the ``nomy_data_models`` package importable from a source checkout by
# adding the directory three levels above this file to ``sys.path``.
# This must run before the package imports below. Because it *appends*,
# any copy of the package found earlier on ``sys.path`` still wins.
sys.path.append(str(Path(__file__).parent.parent.parent))

# Import all models to ensure they are registered with SQLAlchemy
from nomy_data_models.models.base import Base
from nomy_data_models.models.enriched_trade import EnrichedTrade
from nomy_data_models.models.position import (
    MarketType,
    Position,
    PositionDirection,
    PositionStatus,
)
from nomy_data_models.models.raw_trade import RawTrade
from nomy_data_models.models.service_state import ServiceState
from nomy_data_models.models.wallet_state import DataState, SyncState, WalletState


@pytest.fixture(scope="session")
def setup_session():
    """Session-wide setup hook, evaluated at most once per pytest run.

    Currently a placeholder: it performs no work and provides ``None``
    to any test or fixture that requests it.
    """
    return None


@pytest.fixture(scope="function")
def setup_teardown():
    """Per-test fixture: no setup, then ``clear_mappers()`` on teardown.

    NOTE(review): ``clear_mappers()`` removes the ORM mappings of every
    mapped class. With declarative models this generally cannot be undone
    without re-importing the models, so tests that run after one using
    this fixture may no longer be able to use the models. Confirm this is
    intended.
    """
    # Nothing to prepare; hand control straight to the test.
    yield None
    # Runs after the requesting test function has finished.
    clear_mappers()


# SQLite-compatible JSONB type
class JsonbSQLite(TypeDecorator):
    """Store PostgreSQL JSONB values as JSON text on SQLite.

    On the way in, Python values are serialised with :func:`json.dumps`.
    On the way out, they are parsed back with :func:`json.loads`.
    ``None`` passes through unchanged in both directions.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Serialise *value* to a JSON string; ``None`` stays ``None``."""
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        """Parse a stored JSON string back to Python; ``None`` stays ``None``."""
        return None if value is None else json.loads(value)


# When compiling for SQLite, render the PostgreSQL JSONB type as TEXT
@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(element, compiler, **kw):
    """Render the PostgreSQL JSONB type as SQLite ``TEXT``.

    ``element``, ``compiler`` and ``kw`` are supplied by SQLAlchemy's
    compiler-extension hook and are not needed here.
    """
    sqlite_type_name = "TEXT"
    return sqlite_type_name


# SQLite-compatible UUID type
class UuidSQLite(TypeDecorator):
    """Store PostgreSQL UUID values as 36-character strings on SQLite.

    Without bind processing, a :class:`uuid.UUID` would be handed to the
    sqlite3 driver unchanged, and sqlite3 cannot bind that type. UUID
    objects are therefore converted to their canonical hyphenated string
    form before binding. Strings and ``None`` pass through untouched.
    Result values are returned as stored (no conversion back to UUID).
    """

    impl = String(36)
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert :class:`uuid.UUID` values to ``str``; leave others as-is."""
        if isinstance(value, uuid.UUID):
            # str(UUID) is the 36-char hyphenated form, matching String(36).
            return str(value)
        return value


# When compiling for SQLite, render the PostgreSQL UUID type as VARCHAR(36)
@compiles(UUID, "sqlite")
def compile_uuid_sqlite(element, compiler, **kw):
    """Compile the PostgreSQL UUID type as ``VARCHAR(36)`` for SQLite.

    36 characters is the length of a hyphenated UUID string. This matches
    the ``String(36)`` impl used by ``UuidSQLite`` in this module.
    ``element``, ``compiler`` and ``kw`` come from SQLAlchemy's
    compiler-extension hook and are unused.
    """
    return "VARCHAR(36)"


@pytest.fixture(scope="function")
def in_memory_db():
    """Provide an engine bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is copied into a separate
    ``MetaData``. In that copy, PostgreSQL-only column types (JSONB, UUID)
    are swapped for the SQLite adapters defined in this module, and the
    copy is used to create the schema. ``Base.metadata`` itself is not
    modified.

    NOTE(review): ORM-mapped models presumably stay bound to the tables in
    ``Base.metadata``, so the swapped types likely affect only the DDL
    emitted here, not how mappers bind values. Confirm this is intended.

    Yields:
        The SQLAlchemy engine. It is disposed after the test so its pooled
        connections are released instead of lingering until GC.
    """
    engine = create_engine("sqlite:///:memory:", echo=False)

    metadata = MetaData()
    for table in Base.metadata.tables.values():
        # ``Table.tometadata`` was renamed ``to_metadata`` in SQLAlchemy 1.4
        # and the old name is deprecated; fall back for older releases.
        copy_table = getattr(table, "to_metadata", None) or table.tometadata
        new_table = copy_table(metadata)
        for column in new_table.columns:
            if isinstance(column.type, JSONB):
                column.type = JsonbSQLite()
            elif isinstance(column.type, UUID):
                column.type = UuidSQLite()

    # Create all tables from the SQLite-adapted copy of the metadata.
    metadata.create_all(engine, checkfirst=True)

    yield engine

    # Dependent fixtures (e.g. ``session``) are torn down before this runs.
    engine.dispose()


@pytest.fixture(scope="function")
def session(in_memory_db):
    """Yield a SQLAlchemy session bound to the per-test in-memory engine.

    The session is closed once the requesting test has finished.
    """
    db_session = sessionmaker(bind=in_memory_db)()
    yield db_session
    db_session.close()


@pytest.fixture(scope="function")
def persisted_raw_trade(session: Session) -> RawTrade:
    """Insert one ``RawTrade`` via *session*, commit it, and return it refreshed.

    The trade is a SPOT TEST/USD trade on chain 1. It has a base amount of
    1.0 and a token price of 100.0, so the quote amount is base * price.
    Its ``txn_id`` is randomised with a ``raw_`` prefix.

    NOTE(review): the wallet address below is 49 characters, longer than a
    42-character EVM address. If the column is length-limited (not visible
    here), SQLite will not enforce it but PostgreSQL would. Confirm.
    """
    amount = Decimal("1.0")
    price = Decimal("100.0")
    trade = RawTrade(
        txn_id="raw_" + uuid.uuid4().hex,
        event_at=datetime.now(timezone.utc),
        chain_id=1,
        exchange="TestExchange",
        market_type=MarketType.SPOT,
        token_symbol_pair="TEST/USD",
        base_token_symbol="TEST",
        quote_token_symbol="USD",
        wallet_address="0xTestWalletAddress" + "0" * 30,
        base_amount=amount,
        quote_amount=amount * price,
        token_price=price,
        extra_data={"source": "test_fixture"},
    )
    session.add(trade)
    session.commit()
    # Reload server-side/default-populated columns before handing it back.
    session.refresh(trade)
    return trade