import json
import os
import sys
import uuid
from datetime import datetime, timezone
from decimal import Decimal
from pathlib import Path
import pytest
from sqlalchemy import MetaData, String, Text, TypeDecorator, create_engine, inspect
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session, clear_mappers, sessionmaker
sys.path.append(str(Path(__file__).parent.parent.parent))
from nomy_data_models.models.base import Base
from nomy_data_models.models.enriched_trade import EnrichedTrade
from nomy_data_models.models.position import (
MarketType,
Position,
PositionDirection,
PositionStatus,
)
from nomy_data_models.models.raw_trade import RawTrade
from nomy_data_models.models.service_state import ServiceState
from nomy_data_models.models.wallet_state import DataState, SyncState, WalletState
@pytest.fixture(scope="session")
def setup_session():
    """Session-scoped placeholder fixture; performs no setup and no teardown."""
    return None
@pytest.fixture(scope="function")
def setup_teardown():
    """Let the requesting test run, then call ``clear_mappers()`` on teardown.

    The fixture provides no value to the test (it yields ``None``).

    NOTE(review): per the SQLAlchemy docs, ``clear_mappers()`` removes mapper
    instrumentation from every mapped class — presumably including the
    declarative models imported by this module — so tests that use those
    models after this fixture has run may fail; confirm this is intended.
    """
    yield None
    clear_mappers()
class JsonbSQLite(TypeDecorator):
    """Store JSON-serializable Python values in a SQLite ``TEXT`` column.

    Values are encoded with ``json.dumps`` when bound and decoded with
    ``json.loads`` when read back; ``None`` passes through unchanged in both
    directions.
    """

    cache_ok = True
    impl = Text

    def process_bind_param(self, value, dialect):
        """Serialize *value* to a JSON string; ``None`` stays ``None``."""
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        """Parse a stored JSON string back into Python; ``None`` stays ``None``."""
        return None if value is None else json.loads(value)
@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(element, compiler, **kw):
    """Render the PostgreSQL ``JSONB`` type as ``TEXT`` for the SQLite dialect."""
    sqlite_type_name = "TEXT"
    return sqlite_type_name
class UuidSQLite(TypeDecorator):
    """SQLite stand-in for PostgreSQL ``UUID`` columns, backed by ``String(36)``.

    No ``process_bind_param`` / ``process_result_value`` hooks are defined, so
    binding and result handling fall through to the ``String(36)`` impl.

    NOTE(review): binding ``uuid.UUID`` objects through this type would
    presumably need a ``process_bind_param`` that converts them to ``str`` —
    confirm whether any statement actually routes values through it.
    """

    cache_ok = True
    impl = String(36)
@compiles(UUID, "sqlite")
def compile_uuid_sqlite(element, compiler, **kw):
    """Render the PostgreSQL ``UUID`` type as ``VARCHAR(36)`` for the SQLite dialect."""
    sqlite_type_name = "VARCHAR(36)"
    return sqlite_type_name
@pytest.fixture(scope="function")
def in_memory_db():
    """Yield a fresh in-memory SQLite engine containing every model table.

    Each table registered on ``Base.metadata`` is copied into a private
    ``MetaData`` so that PostgreSQL-only column types (``JSONB``, ``UUID``)
    can be swapped for SQLite stand-ins without mutating the shared model
    metadata. The copied schema is then created on the engine.

    The engine is disposed once the test (and any fixture depending on it)
    finishes, releasing its pooled connection and the in-memory database.
    """
    engine = create_engine("sqlite:///:memory:", echo=False)
    sqlite_metadata = MetaData()
    for table in Base.metadata.tables.values():
        # ``to_metadata`` replaces the ``tometadata`` spelling deprecated in
        # SQLAlchemy 1.4; this file already requires 1.4+ (its TypeDecorators
        # set ``cache_ok``).
        sqlite_table = table.to_metadata(sqlite_metadata)
        for column in sqlite_table.columns:
            if isinstance(column.type, JSONB):
                column.type = JsonbSQLite()
            elif isinstance(column.type, UUID):
                column.type = UuidSQLite()
    # NOTE(review): the ORM models remain mapped to ``Base.metadata``, so the
    # type swaps above presumably only affect the DDL emitted here — confirm.
    sqlite_metadata.create_all(engine, checkfirst=True)
    try:
        yield engine
    finally:
        # Previously the engine was returned and never disposed.
        engine.dispose()
@pytest.fixture(scope="function")
def session(in_memory_db):
    """Yield an ORM session bound to the per-test in-memory SQLite engine.

    The session is closed after the requesting test finishes.
    """
    make_session = sessionmaker(bind=in_memory_db)
    db_session = make_session()
    yield db_session
    db_session.close()
@pytest.fixture(scope="function")
def persisted_raw_trade(session: Session) -> RawTrade:
    """Persist one spot-market ``RawTrade`` and return it.

    The trade has a unique ``txn_id`` (``raw_`` + a random UUID hex), a base
    amount of 1.0 at a token price of 100.0 (so a quote amount of their
    product), and the current UTC time as ``event_at``. It is committed and
    then refreshed from the database before being handed to the test.
    """
    price = Decimal("100.0")
    quantity = Decimal("1.0")
    trade_fields = {
        "txn_id": f"raw_{uuid.uuid4().hex}",
        "event_at": datetime.now(timezone.utc),
        "chain_id": 1,
        "exchange": "TestExchange",
        "market_type": MarketType.SPOT,
        "token_symbol_pair": "TEST/USD",
        "base_token_symbol": "TEST",
        "quote_token_symbol": "USD",
        "wallet_address": "0xTestWalletAddress" + "0" * 30,
        "base_amount": quantity,
        "quote_amount": quantity * price,
        "token_price": price,
        "extra_data": {"source": "test_fixture"},
    }
    trade = RawTrade(**trade_fields)

    session.add(trade)
    session.commit()
    session.refresh(trade)
    return trade