# nomy-data-models 0.33.0
#
# Data model definitions for Nomy wallet analysis data processing.
# Documentation
"""Tests for the PositionTrade model."""

import uuid
from datetime import datetime, timezone
from decimal import Decimal

import pytest
from pytest import approx
from sqlalchemy.orm import Session

from nomy_data_models.models.enums import (
    MarketType,
    PositionDirection,
    PositionStatus,
    PositionTradeType,
)
from nomy_data_models.models.position import Position
from nomy_data_models.models.position_trade import PositionTrade
from nomy_data_models.models.raw_trade import RawTrade


@pytest.fixture
def sample_position(session: Session) -> Position:
    """Create, persist, and return an open spot BUY ``Position`` for TST/USD.

    The instance is added to ``session``, committed, and refreshed before it
    is handed to the test, so it is returned after a completed commit.
    """
    position_fields = {
        "position_id": uuid.uuid4(),
        "chain_id": 1,
        "exchange": "test_exchange",
        "market_type": MarketType.SPOT,
        "position_direction": PositionDirection.BUY,
        "wallet_address": "0xTestWalletAddress0000000000000000000000",
        "token_symbol_pair": "TST/USD",
        "base_token_symbol": "TST",
        "quote_token_symbol": "USD",
        "status": PositionStatus.OPEN,
        "current_base_amount": Decimal("1.5"),
        "original_base_amount": Decimal("1.0"),
        "avg_entry_price": Decimal("100.0"),
        "current_average_entry_price": Decimal("100.0"),
        "avg_exit_price": Decimal("0.0"),
        "cost_basis": Decimal("100.0"),
        "opened_at": datetime.now(timezone.utc),
    }
    persisted = Position(**position_fields)

    # Write the row, then reload it from the session before returning.
    session.add(persisted)
    session.commit()
    session.refresh(persisted)
    return persisted


@pytest.fixture
def opening_position_trade_data(
    sample_position: Position, persisted_raw_trade: RawTrade
) -> dict:
    """Build constructor keyword arguments for an OPENING ``PositionTrade``.

    The returned mapping links the trade to ``sample_position`` and to
    ``persisted_raw_trade`` (a fixture defined outside this section), and
    describes an unmatched, non-taker trade of 1.0 at a price of 100.0.
    """
    # Identifiers tying the trade to its position and source raw trade.
    references = {
        "position_id": sample_position.id,
        "raw_trade_id": persisted_raw_trade.id,
        "raw_trade_event_at": persisted_raw_trade.event_at,
    }
    # Values describing the trade itself.
    details = {
        "trade_type": PositionTradeType.OPENING,
        "amount": Decimal("1.0"),
        "price": Decimal("100.0"),
        "event_at": datetime.now(timezone.utc),
        "fees_total": Decimal("0.1"),
        "unmatched_amount": Decimal("1.0"),
        "is_fully_matched": False,
        "is_taker": False,
        "extra_data": {"source": "test"},
    }
    return {**references, **details}


@pytest.fixture
def closing_position_trade_data(
    sample_position: Position, persisted_raw_trade: RawTrade
) -> dict:
    """Build constructor keyword arguments for a CLOSING ``PositionTrade``.

    The returned mapping links the trade to ``sample_position`` and to
    ``persisted_raw_trade`` (a fixture defined outside this section), and
    describes an unmatched taker trade of 0.5 at a price of 120.0. No
    ``extra_data`` entry is included.
    """
    # Identifiers tying the trade to its position and source raw trade.
    references = {
        "position_id": sample_position.id,
        "raw_trade_id": persisted_raw_trade.id,
        "raw_trade_event_at": persisted_raw_trade.event_at,
    }
    # Values describing the trade itself.
    details = {
        "trade_type": PositionTradeType.CLOSING,
        "amount": Decimal("0.5"),
        "price": Decimal("120.0"),
        "event_at": datetime.now(timezone.utc),
        "fees_total": Decimal("0.05"),
        "unmatched_amount": Decimal("0.5"),
        "is_fully_matched": False,
        "is_taker": True,
    }
    return {**references, **details}


class TestPositionTrade:
    """Tests for constructing and persisting ``PositionTrade`` instances."""

    def test_position_trade_creation(self, opening_position_trade_data: dict) -> None:
        """Check that the constructor keeps every supplied field unchanged.

        Nothing is written to the database here, so ``id`` is expected to
        remain ``None``.
        """
        trade = PositionTrade(**opening_position_trade_data)

        assert trade.id is None  # ID is set upon DB insertion
        assert trade.position_id == opening_position_trade_data["position_id"]
        assert trade.raw_trade_id == opening_position_trade_data["raw_trade_id"]
        assert (
            trade.raw_trade_event_at
            == opening_position_trade_data["raw_trade_event_at"]
        )
        assert trade.trade_type == PositionTradeType.OPENING
        assert trade.amount == Decimal("1.0")
        assert trade.price == Decimal("100.0")
        assert trade.event_at == opening_position_trade_data["event_at"]
        assert trade.fees_total == Decimal("0.1")
        assert trade.unmatched_amount == Decimal("1.0")
        assert not trade.is_fully_matched
        assert not trade.is_taker
        assert trade.extra_data == {"source": "test"}

    def test_position_trade_db_integration(
        self, session: Session, opening_position_trade_data: dict
    ) -> None:
        """Save an opening trade, query it back, and compare it to the input.

        Also checks that the ``position`` relationship resolves to the
        position the trade was created for.
        """
        trade_data = opening_position_trade_data
        trade = PositionTrade(**trade_data)

        session.add(trade)
        session.commit()

        assert trade.id is not None  # Verify ID is set after commit

        retrieved = session.query(PositionTrade).filter_by(id=trade.id).first()
        assert retrieved is not None
        assert retrieved.id == trade.id
        assert retrieved.position_id == trade_data["position_id"]
        assert retrieved.raw_trade_id == trade_data["raw_trade_id"]
        assert retrieved.raw_trade_event_at == trade_data["raw_trade_event_at"]
        assert retrieved.trade_type == PositionTradeType.OPENING
        # Numeric columns are compared with a tolerance; presumably the backend
        # may not hand back the exact Decimal that was stored.
        assert retrieved.amount == approx(Decimal("1.0"))
        assert retrieved.price == approx(Decimal("100.0"))

        # ``replace(tzinfo=...)`` only attaches a zone and never converts, so it
        # is applied to naive values alone (treated as UTC). Aware values are
        # compared directly, because aware datetimes compare by instant.
        stored_event_at = retrieved.event_at
        if stored_event_at.tzinfo is None:
            stored_event_at = stored_event_at.replace(tzinfo=timezone.utc)
        assert stored_event_at == trade_data["event_at"]

        assert retrieved.fees_total == approx(Decimal("0.1"))
        assert retrieved.unmatched_amount == approx(Decimal("1.0"))

        # Test relationship back to Position
        assert retrieved.position is not None
        assert retrieved.position.id == trade_data["position_id"]

    def test_closing_position_trade(
        self, session: Session, closing_position_trade_data: dict
    ) -> None:
        """Save a closing trade, query it back, and check its key fields."""
        trade_data = closing_position_trade_data
        trade = PositionTrade(**trade_data)

        session.add(trade)
        session.commit()

        retrieved = session.query(PositionTrade).filter_by(id=trade.id).first()
        assert retrieved is not None
        assert retrieved.raw_trade_id == trade_data["raw_trade_id"]
        assert retrieved.raw_trade_event_at == trade_data["raw_trade_event_at"]
        assert retrieved.trade_type == PositionTradeType.CLOSING
        # Tolerance-based comparison, matching the other round-trip test in
        # this class, so both tests treat persisted numerics the same way.
        assert retrieved.amount == approx(Decimal("0.5"))
        assert retrieved.price == approx(Decimal("120.0"))
        assert retrieved.is_taker