nomy-data-models 0.35.6

Data model definitions for Nomy wallet analysis data processing.
Documentation
"""Tests for the TradeMatch model."""

import uuid
from datetime import datetime, timedelta, timezone
from decimal import Decimal

import pytest
from pytest import approx
from sqlalchemy.orm import Session

from nomy_data_models.models.enums import (
    MarketType,
    PositionDirection,
    PositionStatus,
    PositionTradeType,
)
from nomy_data_models.models.position import Position
from nomy_data_models.models.position_trade import PositionTrade
from nomy_data_models.models.raw_trade import RawTrade
from nomy_data_models.models.trade_match import TradeMatch


@pytest.fixture
def sample_position_for_match(session: Session) -> Position:
    """Create, commit and return an open position that trade matches attach to.

    The position is a SPOT BUY of one ``MTC`` unit quoted in ``USD`` at an
    average entry price of 200, opened one day before the fixture runs.
    After committing, it is refreshed from the session so its attributes
    (e.g. ``id``, which the trade fixtures use) reflect the stored row.
    """
    opened = datetime.now(timezone.utc) - timedelta(days=1)
    one_unit = Decimal("1.0")
    entry_price = Decimal("200.0")

    new_position = Position(
        position_id=uuid.uuid4(),
        chain_id=1,
        exchange="test_exchange_match",
        market_type=MarketType.SPOT,
        position_direction=PositionDirection.BUY,
        wallet_address="0xMatchWalletAddress000000000000000000000",
        token_symbol_pair="MTC/USD",
        base_token_symbol="MTC",
        quote_token_symbol="USD",
        status=PositionStatus.OPEN,
        # Starts with the full unit still held (current == original).
        current_base_amount=one_unit,
        original_base_amount=one_unit,
        avg_entry_price=entry_price,
        current_avg_entry_price=entry_price,
        avg_exit_price=Decimal("0.0"),
        cost_basis=entry_price,  # 1 unit at 200
        opened_at=opened,
    )

    session.add(new_position)
    session.commit()
    session.refresh(new_position)
    return new_position


@pytest.fixture
def persisted_opening_trade(
    session: Session,
    sample_position_for_match: Position,
    persisted_raw_trade: RawTrade,
) -> PositionTrade:
    """Persist and return the OPENING trade of the sample position.

    The trade buys one unit at 200 with 0.2 in fees. It is timestamped at the
    position's ``opened_at`` and is linked to ``persisted_raw_trade``. Its
    whole amount starts unmatched.
    """
    position = sample_position_for_match
    raw = persisted_raw_trade
    size = Decimal("1.0")

    opening = PositionTrade(
        position_id=position.id,
        raw_trade_id=raw.id,
        raw_trade_txn_id=raw.txn_id,
        trade_type=PositionTradeType.OPENING,
        amount=size,
        price=Decimal("200.0"),
        event_at=position.opened_at,
        fees_total=Decimal("0.2"),
        # Nothing has been matched against a closing trade yet.
        unmatched_amount=size,
        is_fully_matched=False,
    )

    session.add(opening)
    session.commit()
    session.refresh(opening)
    return opening


@pytest.fixture
def persisted_closing_trade(
    session: Session,
    sample_position_for_match: Position,
    persisted_raw_trade: RawTrade,
) -> PositionTrade:
    """Persist and return a CLOSING trade on the sample position.

    The trade closes 0.8 units at 250 with 0.25 in fees and is timestamped
    at the moment the fixture runs. It reuses ``persisted_raw_trade`` as its
    raw source, and its whole amount starts unmatched.
    """
    position = sample_position_for_match
    raw = persisted_raw_trade
    size = Decimal("0.8")
    closed_at = datetime.now(timezone.utc)

    closing = PositionTrade(
        position_id=position.id,
        raw_trade_id=raw.id,
        raw_trade_txn_id=raw.txn_id,
        trade_type=PositionTradeType.CLOSING,
        amount=size,
        price=Decimal("250.0"),
        event_at=closed_at,
        fees_total=Decimal("0.25"),
        # Nothing has been matched against an opening trade yet.
        unmatched_amount=size,
        is_fully_matched=False,
    )

    session.add(closing)
    session.commit()
    session.refresh(closing)
    return closing


@pytest.fixture
def trade_match_data(
    sample_position_for_match: Position,
    persisted_opening_trade: PositionTrade,
    persisted_closing_trade: PositionTrade,
) -> dict:
    """Provide constructor kwargs for a ``TradeMatch`` linking the two trades.

    Matches 0.8 units of the opening trade against the closing trade. The
    expected PnL, ROI and holding duration are derived from the persisted
    trades' prices and timestamps. ``match_created_at`` is deliberately
    omitted; the tests rely on the model defaulting it.
    """
    matched_amount = Decimal("0.8")
    entry_price = persisted_opening_trade.price
    exit_price = persisted_closing_trade.price

    # Simple PnL for a BUY position: profit when exit is above entry.
    pnl = (exit_price - entry_price) * matched_amount
    # The position's quote token is USD, so PnL is already in USD.
    pnl_usd = pnl

    # ROI relative to the capital committed to the matched slice. Guard the
    # division so degenerate prices/amounts yield 0 instead of raising.
    if entry_price > 0 and matched_amount > 0:
        roi = pnl / (entry_price * matched_amount)
    else:
        roi = Decimal("0")

    # Subtract the timestamps directly. Converting to float seconds and back
    # via total_seconds() adds nothing.
    holding_duration = (
        persisted_closing_trade.event_at - persisted_opening_trade.event_at
    )

    return {
        "position_id": sample_position_for_match.id,
        "opening_trade_id": persisted_opening_trade.id,
        "closing_trade_id": persisted_closing_trade.id,
        "matched_amount": matched_amount,
        "entry_price": entry_price,
        "exit_price": exit_price,
        "pnl": pnl,
        "pnl_usd": pnl_usd,
        "roi": roi,
        "holding_duration": holding_duration,
    }


class TestTradeMatch:
    """Test cases for the TradeMatch model."""

    def test_trade_match_creation(self, trade_match_data: dict) -> None:
        """A freshly constructed TradeMatch exposes exactly the values given."""
        match = TradeMatch(**trade_match_data)

        # No INSERT has happened yet, so the primary key is still unset.
        assert match.id is None
        assert match.position_id == trade_match_data["position_id"]
        assert match.opening_trade_id == trade_match_data["opening_trade_id"]
        assert match.closing_trade_id == trade_match_data["closing_trade_id"]
        assert match.matched_amount == Decimal("0.8")
        assert match.entry_price == Decimal("200.0")
        assert match.exit_price == Decimal("250.0")
        assert match.pnl == Decimal("40.0")  # (250 - 200) * 0.8
        assert match.pnl_usd == Decimal("40.0")
        assert match.roi == Decimal("0.25")  # 40 / (200 * 0.8)
        # Pin the exact duration rather than merely checking it is set. The
        # closing trade happens after the opening one, so it is positive.
        assert match.holding_duration == trade_match_data["holding_duration"]
        assert match.holding_duration > timedelta(0)
        # match_created_at is defaulted by the model on insert.

    def test_trade_match_db_integration(
        self, session: Session, trade_match_data: dict
    ) -> None:
        """Persist a TradeMatch and verify columns and relationships round-trip."""
        match = TradeMatch(**trade_match_data)

        session.add(match)
        session.commit()

        # Populated by the INSERT: primary key and defaulted creation time.
        assert match.id is not None
        assert match.match_created_at is not None
        match_id = match.id

        # Discard cached attribute state so the values below are reloaded
        # from the database, whatever the session's expire_on_commit setting.
        session.expire_all()

        retrieved = session.query(TradeMatch).filter_by(id=match_id).first()
        assert retrieved is not None
        assert retrieved.id == match_id
        assert retrieved.position_id == trade_match_data["position_id"]
        assert retrieved.opening_trade_id == trade_match_data["opening_trade_id"]
        assert retrieved.closing_trade_id == trade_match_data["closing_trade_id"]
        assert retrieved.matched_amount == approx(Decimal("0.8"))
        assert retrieved.pnl_usd == approx(Decimal("40.0"))
        assert retrieved.entry_price == approx(trade_match_data["entry_price"])
        assert retrieved.exit_price == approx(trade_match_data["exit_price"])
        assert retrieved.pnl == approx(trade_match_data["pnl"])
        assert retrieved.roi == approx(trade_match_data["roi"])
        assert isinstance(retrieved.holding_duration, timedelta)

        # Relationships resolve to the rows referenced by the foreign keys.
        assert retrieved.position is not None
        assert retrieved.position.id == trade_match_data["position_id"]
        assert retrieved.opening_trade is not None
        assert retrieved.opening_trade.id == trade_match_data["opening_trade_id"]
        assert retrieved.closing_trade is not None
        assert retrieved.closing_trade.id == trade_match_data["closing_trade_id"]