import uuid
from datetime import datetime, timezone
from decimal import Decimal
import pytest
from pytest import approx
from sqlalchemy.orm import Session
from nomy_data_models.models.enums import (
MarketType,
PositionDirection,
PositionStatus,
PositionTradeType,
)
from nomy_data_models.models.position import Position
from nomy_data_models.models.position_trade import PositionTrade
from nomy_data_models.models.raw_trade import RawTrade
@pytest.fixture
def sample_position(session: Session) -> Position:
    """Persist an OPEN spot BUY position for TST/USD and return it.

    The row is added to ``session``, committed, and refreshed before being
    handed to the test, so the returned object reflects the stored state.
    """
    attributes = {
        "position_id": uuid.uuid4(),
        "chain_id": 1,
        "exchange": "test_exchange",
        "market_type": MarketType.SPOT,
        "position_direction": PositionDirection.BUY,
        "wallet_address": "0xTestWalletAddress0000000000000000000000",
        "token_symbol_pair": "TST/USD",
        "base_token_symbol": "TST",
        "quote_token_symbol": "USD",
        "status": PositionStatus.OPEN,
        "current_base_amount": Decimal("1.5"),
        "original_base_amount": Decimal("1.0"),
        "avg_entry_price": Decimal("100.0"),
        "current_average_entry_price": Decimal("100.0"),
        "avg_exit_price": Decimal("0.0"),
        "cost_basis": Decimal("100.0"),
        "opened_at": datetime.now(timezone.utc),
    }
    stored_position = Position(**attributes)

    session.add(stored_position)
    session.commit()
    # Reload the row so the instance carries whatever the commit populated.
    session.refresh(stored_position)
    return stored_position
@pytest.fixture
def opening_position_trade_data(
    sample_position: Position, persisted_raw_trade: RawTrade
) -> dict:
    """Return keyword arguments describing an OPENING position trade.

    The mapping links the trade to ``sample_position`` and
    ``persisted_raw_trade`` and describes a non-taker trade of 1.0 at a
    price of 100.0 that is still entirely unmatched.
    """
    # References to the rows the trade hangs off.
    linkage = {
        "position_id": sample_position.id,
        "raw_trade_id": persisted_raw_trade.id,
        "raw_trade_event_at": persisted_raw_trade.event_at,
    }
    # Attributes of the trade itself.
    details = {
        "trade_type": PositionTradeType.OPENING,
        "amount": Decimal("1.0"),
        "price": Decimal("100.0"),
        "event_at": datetime.now(timezone.utc),
        "fees_total": Decimal("0.1"),
        "unmatched_amount": Decimal("1.0"),
        "is_fully_matched": False,
        "is_taker": False,
        "extra_data": {"source": "test"},
    }
    return {**linkage, **details}
@pytest.fixture
def closing_position_trade_data(
    sample_position: Position, persisted_raw_trade: RawTrade
) -> dict:
    """Return keyword arguments describing a CLOSING position trade.

    The mapping links the trade to ``sample_position`` and
    ``persisted_raw_trade`` and describes a taker trade of 0.5 at a price
    of 120.0 that is still entirely unmatched. No ``extra_data`` is set.
    """
    # References to the rows the trade hangs off.
    linkage = {
        "position_id": sample_position.id,
        "raw_trade_id": persisted_raw_trade.id,
        "raw_trade_event_at": persisted_raw_trade.event_at,
    }
    # Attributes of the trade itself.
    details = {
        "trade_type": PositionTradeType.CLOSING,
        "amount": Decimal("0.5"),
        "price": Decimal("120.0"),
        "event_at": datetime.now(timezone.utc),
        "fees_total": Decimal("0.05"),
        "unmatched_amount": Decimal("0.5"),
        "is_fully_matched": False,
        "is_taker": True,
    }
    return {**linkage, **details}
class TestPositionTrade:
    """Tests for constructing ``PositionTrade`` objects and storing them."""

    def test_position_trade_creation(self, opening_position_trade_data: dict) -> None:
        """Build a trade in memory and check each attribute given to the constructor."""
        trade = PositionTrade(**opening_position_trade_data)

        # The object is never added to a session in this test, so the test
        # expects its id to still be unset.
        assert trade.id is None
        assert trade.position_id == opening_position_trade_data["position_id"]
        assert trade.raw_trade_id == opening_position_trade_data["raw_trade_id"]
        assert (
            trade.raw_trade_event_at
            == opening_position_trade_data["raw_trade_event_at"]
        )
        assert trade.trade_type == PositionTradeType.OPENING
        assert trade.amount == Decimal("1.0")
        assert trade.price == Decimal("100.0")
        assert trade.event_at == opening_position_trade_data["event_at"]
        assert trade.fees_total == Decimal("0.1")
        assert trade.unmatched_amount == Decimal("1.0")
        assert not trade.is_fully_matched
        assert not trade.is_taker
        assert trade.extra_data == {"source": "test"}

    def test_position_trade_db_integration(
        self, session: Session, opening_position_trade_data: dict
    ) -> None:
        """Commit an opening trade, query it back, and compare it to the input data."""
        trade_data = opening_position_trade_data
        trade = PositionTrade(**trade_data)
        session.add(trade)
        session.commit()

        # After the commit the trade must have been given an id.
        assert trade.id is not None

        retrieved = session.query(PositionTrade).filter_by(id=trade.id).first()
        assert retrieved is not None
        assert retrieved.id == trade.id
        assert retrieved.position_id == trade_data["position_id"]
        assert retrieved.raw_trade_id == trade_data["raw_trade_id"]
        assert retrieved.raw_trade_event_at == trade_data["raw_trade_event_at"]
        assert retrieved.trade_type == PositionTradeType.OPENING
        # Numeric columns are compared with a tolerance rather than exactly.
        assert retrieved.amount == approx(Decimal("1.0"))
        assert retrieved.price == approx(Decimal("100.0"))
        # The stored timestamp is tagged as UTC before comparing against the
        # timezone-aware input value.
        # NOTE(review): presumably the backend returns naive datetimes - confirm.
        assert retrieved.event_at.replace(tzinfo=timezone.utc) == trade_data["event_at"]
        assert retrieved.fees_total == approx(Decimal("0.1"))
        assert retrieved.unmatched_amount == approx(Decimal("1.0"))

        # The relationship back to the owning position must resolve.
        assert retrieved.position is not None
        assert retrieved.position.id == trade_data["position_id"]

    def test_closing_position_trade(
        self, session: Session, closing_position_trade_data: dict
    ) -> None:
        """Commit a closing trade, query it back, and check its key fields."""
        trade_data = closing_position_trade_data
        trade = PositionTrade(**trade_data)
        session.add(trade)
        session.commit()

        retrieved = session.query(PositionTrade).filter_by(id=trade.id).first()
        assert retrieved is not None
        assert retrieved.raw_trade_id == trade_data["raw_trade_id"]
        assert retrieved.raw_trade_event_at == trade_data["raw_trade_event_at"]
        assert retrieved.trade_type == PositionTradeType.CLOSING
        assert retrieved.amount == Decimal("0.5")
        assert retrieved.price == Decimal("120.0")
        assert retrieved.is_taker