import uuid
from datetime import datetime, timedelta, timezone
from decimal import Decimal
import pytest
from pytest import approx
from sqlalchemy.orm import Session
from nomy_data_models.models.enums import (
MarketType,
PositionDirection,
PositionStatus,
PositionTradeType,
)
from nomy_data_models.models.position import Position
from nomy_data_models.models.position_trade import PositionTrade
from nomy_data_models.models.raw_trade import RawTrade
from nomy_data_models.models.trade_match import TradeMatch
@pytest.fixture
def sample_position_for_match(session: Session) -> Position:
    """Persist and return an open spot BUY position for the match tests.

    The position holds 1.0 MTC against USD at an entry price of 200.0 and
    was opened one day before the fixture runs.
    """
    one_unit = Decimal("1.0")
    entry = Decimal("200.0")
    opened = datetime.now(timezone.utc) - timedelta(days=1)

    pos = Position(
        position_id=uuid.uuid4(),
        chain_id=1,
        exchange="test_exchange_match",
        market_type=MarketType.SPOT,
        position_direction=PositionDirection.BUY,
        wallet_address="0xMatchWalletAddress000000000000000000000",
        token_symbol_pair="MTC/USD",
        base_token_symbol="MTC",
        quote_token_symbol="USD",
        status=PositionStatus.OPEN,
        current_base_amount=one_unit,
        original_base_amount=one_unit,
        avg_entry_price=entry,
        current_average_entry_price=entry,
        avg_exit_price=Decimal("0.0"),
        cost_basis=entry,
        opened_at=opened,
    )
    session.add(pos)
    session.commit()
    # Reload all attributes from the database so callers see the persisted
    # values (dependent fixtures read ``id`` and ``opened_at``).
    session.refresh(pos)
    return pos
@pytest.fixture
def persisted_opening_trade(
    session: Session,
    sample_position_for_match: Position,
    persisted_raw_trade: RawTrade,
) -> PositionTrade:
    """Persist and return an unmatched OPENING trade on the sample position.

    The trade covers 1.0 units at price 200.0 with 0.2 in fees. It is
    timestamped at the position's ``opened_at`` and linked to
    ``persisted_raw_trade``, a fixture defined outside this module
    (presumably in ``conftest.py`` — confirm).
    """
    qty = Decimal("1.0")
    raw = persisted_raw_trade

    opening = PositionTrade(
        position_id=sample_position_for_match.id,
        raw_trade_id=raw.id,
        raw_trade_event_at=raw.event_at,
        trade_type=PositionTradeType.OPENING,
        amount=qty,
        price=Decimal("200.0"),
        event_at=sample_position_for_match.opened_at,
        fees_total=Decimal("0.2"),
        # Nothing has been matched yet, so the full amount is outstanding.
        unmatched_amount=qty,
        is_fully_matched=False,
    )
    session.add(opening)
    session.commit()
    session.refresh(opening)
    return opening
@pytest.fixture
def persisted_closing_trade(
    session: Session,
    sample_position_for_match: Position,
    persisted_raw_trade: RawTrade,
) -> PositionTrade:
    """Persist and return an unmatched CLOSING trade on the sample position.

    The trade covers 0.8 units at price 250.0 with 0.25 in fees. It is
    timestamped at the current UTC time and linked to
    ``persisted_raw_trade``, a fixture defined outside this module
    (presumably in ``conftest.py`` — confirm).
    """
    qty = Decimal("0.8")
    raw = persisted_raw_trade
    closed = datetime.now(timezone.utc)

    closing = PositionTrade(
        position_id=sample_position_for_match.id,
        raw_trade_id=raw.id,
        raw_trade_event_at=raw.event_at,
        trade_type=PositionTradeType.CLOSING,
        amount=qty,
        price=Decimal("250.0"),
        event_at=closed,
        fees_total=Decimal("0.25"),
        # Nothing has been matched yet, so the full amount is outstanding.
        unmatched_amount=qty,
        is_fully_matched=False,
    )
    session.add(closing)
    session.commit()
    session.refresh(closing)
    return closing
@pytest.fixture
def trade_match_data(
    sample_position_for_match: Position,
    persisted_opening_trade: PositionTrade,
    persisted_closing_trade: PositionTrade,
) -> dict:
    """Build ``TradeMatch`` constructor kwargs pairing the two trades.

    Matches 0.8 units of the opening trade against the closing trade and
    derives the expected figures from the persisted trade prices:

    * ``pnl`` is ``(exit_price - entry_price) * matched_amount``, and
      ``pnl_usd`` is set to the same value.
    * ``roi`` is ``pnl / (entry_price * matched_amount)``, or
      ``Decimal("0")`` when either factor is not positive.
    * ``holding_duration_seconds`` is the float number of seconds between
      the opening and closing trades' ``event_at``.

    Returns:
        A dict whose keys are ``TradeMatch`` column names.
    """
    matched_amount = Decimal("0.8")
    entry_price = persisted_opening_trade.price
    exit_price = persisted_closing_trade.price

    pnl = (exit_price - entry_price) * matched_amount
    # NOTE(review): pnl_usd mirrors pnl, presumably because the position's
    # quote token is USD — confirm if non-USD quotes are ever tested.
    pnl_usd = pnl

    # Guard the division so a zero/negative cost basis yields 0 instead of
    # raising decimal.DivisionByZero or producing a meaningless ratio.
    roi = (
        pnl / (entry_price * matched_amount)
        if entry_price > 0 and matched_amount > 0
        else Decimal("0")
    )
    holding_duration_seconds = (
        persisted_closing_trade.event_at - persisted_opening_trade.event_at
    ).total_seconds()

    return {
        "position_id": sample_position_for_match.id,
        "opening_trade_id": persisted_opening_trade.id,
        "closing_trade_id": persisted_closing_trade.id,
        "matched_amount": matched_amount,
        "entry_price": entry_price,
        "exit_price": exit_price,
        "pnl": pnl,
        "pnl_usd": pnl_usd,
        "roi": roi,
        "holding_duration_seconds": holding_duration_seconds,
    }
class TestTradeMatch:
    """Tests for constructing and persisting ``TradeMatch`` rows."""

    def test_trade_match_creation(self, trade_match_data: dict) -> None:
        """A freshly constructed ``TradeMatch`` exposes the kwargs it was given."""
        match = TradeMatch(**trade_match_data)

        # Never added to a session, so no primary key has been assigned yet
        # (presumably ``id`` is generated on flush — confirm in the model).
        assert match.id is None
        assert match.position_id == trade_match_data["position_id"]
        assert match.opening_trade_id == trade_match_data["opening_trade_id"]
        assert match.closing_trade_id == trade_match_data["closing_trade_id"]
        assert match.matched_amount == Decimal("0.8")
        assert match.entry_price == Decimal("200.0")
        assert match.exit_price == Decimal("250.0")
        # (250.0 - 200.0) * 0.8 == 40 and 40 / (200.0 * 0.8) == 0.25.
        assert match.pnl == Decimal("40.0")
        assert match.pnl_usd == Decimal("40.0")
        assert match.roi == Decimal("0.25")
        assert match.holding_duration_seconds > 0

    def test_trade_match_db_integration(
        self, session: Session, trade_match_data: dict
    ) -> None:
        """Persist a ``TradeMatch`` and verify its columns and relationships."""
        match = TradeMatch(**trade_match_data)
        session.add(match)
        session.commit()

        assert match.id is not None
        assert match.match_created_at is not None

        # Expire cached state so the query below reloads every column from
        # the database, rather than returning the in-memory values we just
        # wrote (which would happen if the session used expire_on_commit=False).
        session.expire_all()
        retrieved = session.query(TradeMatch).filter_by(id=match.id).first()

        assert retrieved is not None
        assert retrieved.id == match.id
        assert retrieved.position_id == trade_match_data["position_id"]
        assert retrieved.opening_trade_id == trade_match_data["opening_trade_id"]
        assert retrieved.closing_trade_id == trade_match_data["closing_trade_id"]

        # Numeric columns may come back with different scale/precision, so
        # compare approximately.
        assert retrieved.matched_amount == approx(Decimal("0.8"))
        assert retrieved.pnl_usd == approx(Decimal("40.0"))
        assert retrieved.entry_price == approx(trade_match_data["entry_price"])
        assert retrieved.exit_price == approx(trade_match_data["exit_price"])
        assert retrieved.pnl == approx(trade_match_data["pnl"])
        assert retrieved.roi == approx(trade_match_data["roi"])
        assert retrieved.holding_duration_seconds == approx(
            trade_match_data["holding_duration_seconds"]
        )

        # Relationships resolve to the rows created by the fixtures.
        assert retrieved.position is not None
        assert retrieved.position.id == trade_match_data["position_id"]
        assert retrieved.opening_trade is not None
        assert retrieved.opening_trade.id == trade_match_data["opening_trade_id"]
        assert retrieved.closing_trade is not None
        assert retrieved.closing_trade.id == trade_match_data["closing_trade_id"]