import asyncio
import json
import pytest
from unittest.mock import AsyncMock
from forge.context.manager import ContextManager
from forge.context.strategies import NoCompact
from forge.core.workflow import TextResponse, ToolCall
from forge.proxy.server import HTTPServer
def _mock_client(response):
    """Return an ``AsyncMock`` standing in for the LLM client.

    The mock reports ``api_format == "ollama"`` and its awaitable ``send``
    resolves to *response* no matter what arguments it is called with.
    """
    return AsyncMock(
        api_format="ollama",
        send=AsyncMock(return_value=response),
    )
@pytest.fixture
async def server_factory():
    """Yield an async factory that starts ``HTTPServer`` instances for a test.

    The factory is called as ``await factory(response, serialize=False)`` and
    returns ``(server, port)``. Each server is built around a mocked client
    whose ``send`` resolves to *response*, and a ``ContextManager`` using the
    ``NoCompact`` strategy with an 8192-token budget. Every server the factory
    handed out is stopped once the test is over.
    """
    started = []

    async def _make(response, serialize=False):
        server = HTTPServer(
            client=_mock_client(response),
            context_manager=ContextManager(
                strategy=NoCompact(),
                budget_tokens=8192,
            ),
            host="127.0.0.1",
            # port=0 is passed in; the real port is read back from the socket.
            port=0,
            serialize_requests=serialize,
        )
        await server.start()
        # NOTE(review): relies on the server's private ``_server`` attribute
        # exposing its listening sockets.
        bound_port = server._server.sockets[0].getsockname()[1]
        started.append(server)
        return server, bound_port

    yield _make

    # Teardown: stop servers in the order they were created.
    for server in started:
        await server.stop()
async def _http_request(port, method, path, body=None):
    """Send one HTTP/1.1 request to ``127.0.0.1:port`` and return the reply.

    Args:
        port: TCP port of the server under test.
        method: HTTP method written into the request line.
        path: Request target written into the request line.
        body: Optional JSON-serializable payload. When given, it is sent with
            ``Content-Type: application/json`` and a matching
            ``Content-Length``; when ``None`` only the request line and the
            ``Host`` header are sent.

    Returns:
        ``(status, response_body)``: the integer status code from the status
        line and the text following the first blank line (``""`` when the
        response has no header terminator).

    Raises:
        asyncio.TimeoutError: the response was not complete within 10 seconds.
        IndexError, ValueError: the status line could not be parsed.
    """
    header_lines = [
        f"{method} {path} HTTP/1.1",
        f"Host: 127.0.0.1:{port}",
    ]
    payload = b""
    if body is not None:
        payload = json.dumps(body).encode()
        header_lines.append("Content-Type: application/json")
        header_lines.append(f"Content-Length: {len(payload)}")
    request = "\r\n".join(header_lines).encode() + b"\r\n\r\n" + payload

    def _declared_length(raw_head):
        """Return the Content-Length announced in *raw_head*, or ``None``."""
        for header in raw_head.split(b"\r\n")[1:]:
            name, colon, value = header.partition(b":")
            if colon and name.strip().lower() == b"content-length":
                try:
                    return int(value.strip())
                except ValueError:
                    return None
        return None

    async def _read_response(reader):
        """Read until the response is complete, or as complete as we can tell."""
        received = b""
        while True:
            chunk = await reader.read(65536)
            if not chunk:
                # EOF: the peer closed the connection; return what arrived.
                return received
            received += chunk
            raw_head, terminator, partial_body = received.partition(b"\r\n\r\n")
            if not terminator:
                continue  # headers are still incomplete
            # A HEAD reply may announce a length but never carries a body.
            if str(method).upper() == "HEAD":
                return received
            expected = _declared_length(raw_head)
            # Without a usable Content-Length there is no way to know how much
            # more is coming, so return what has arrived so far.
            if expected is None or len(partial_body) >= expected:
                return received

    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    try:
        writer.write(request)
        await writer.drain()
        # A single read() may deliver only the headers when the server writes
        # the body separately, so loop until the declared body has arrived.
        raw = await asyncio.wait_for(_read_response(reader), timeout=10.0)
        text = raw.decode("utf-8", errors="replace")
        status_line = text.split("\r\n", 1)[0]
        status = int(status_line.split(" ", 2)[1])
        _, terminator, response_body = text.partition("\r\n\r\n")
        return status, (response_body if terminator else "")
    finally:
        writer.close()
        await writer.wait_closed()
async def _sse_request(port, body):
    """POST *body* to ``/v1/chat/completions`` and collect SSE ``data:`` payloads.

    Args:
        port: TCP port of the server under test on 127.0.0.1.
        body: JSON-serializable request payload.

    Returns:
        A list with the text after ``"data: "`` for every response line that
        starts with that prefix (after stripping surrounding whitespace), in
        arrival order. The ``[DONE]`` sentinel is included when it was sent.

    Raises:
        asyncio.TimeoutError: no response bytes arrived within 10 seconds.
    """
    payload = json.dumps(body).encode()
    request = (
        f"POST /v1/chat/completions HTTP/1.1\r\n"
        f"Host: 127.0.0.1:{port}\r\n"
        f"Content-Type: application/json\r\n"
        f"Content-Length: {len(payload)}\r\n"
        f"\r\n"
    ).encode() + payload

    total_budget = 10.0  # seconds allowed for the whole response
    idle_budget = 1.0    # seconds of silence after which the stream is considered over

    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    try:
        writer.write(request)
        await writer.drain()

        loop = asyncio.get_running_loop()
        deadline = loop.time() + total_budget
        received = await asyncio.wait_for(
            reader.read(65536), timeout=total_budget,
        )
        # Events are written incrementally, so one read() may end before the
        # terminator. Keep reading until it shows up, the peer closes, the
        # stream goes quiet, or the overall budget is spent.
        while received and b"data: [DONE]" not in received:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                chunk = await asyncio.wait_for(
                    reader.read(65536), timeout=min(remaining, idle_budget),
                )
            except asyncio.TimeoutError:
                # Nothing more is coming (e.g. a non-streaming error reply on
                # a kept-alive connection); parse what we already have.
                break
            if not chunk:
                break
            received += chunk

        text = received.decode("utf-8", errors="replace")
        stripped_lines = (line.strip() for line in text.split("\n"))
        return [
            line[6:] for line in stripped_lines if line.startswith("data: ")
        ]
    finally:
        writer.close()
        await writer.wait_closed()
class TestHealthAndModels:
    """Checks for the routes that do not involve a chat completion."""

    @pytest.mark.asyncio
    async def test_health_endpoint(self, server_factory):
        """GET /health answers 200 with a JSON ``status`` of ``"ok"``."""
        _, port = await server_factory(TextResponse(content=""))
        status, raw = await _http_request(port, "GET", "/health")
        assert status == 200
        assert json.loads(raw)["status"] == "ok"

    @pytest.mark.asyncio
    async def test_models_endpoint(self, server_factory):
        """GET /v1/models answers 200 with a non-empty ``list`` object."""
        _, port = await server_factory(TextResponse(content=""))
        status, raw = await _http_request(port, "GET", "/v1/models")
        assert status == 200
        listing = json.loads(raw)
        assert listing["object"] == "list"
        assert len(listing["data"]) > 0

    @pytest.mark.asyncio
    async def test_not_found(self, server_factory):
        """An unknown path answers 404."""
        _, port = await server_factory(TextResponse(content=""))
        status, _ = await _http_request(port, "GET", "/nonexistent")
        assert status == 404

    @pytest.mark.asyncio
    async def test_cors_preflight(self, server_factory):
        """An OPTIONS request to the chat route answers 204."""
        _, port = await server_factory(TextResponse(content=""))
        status, _ = await _http_request(
            port, "OPTIONS", "/v1/chat/completions",
        )
        assert status == 204
class TestChatCompletions:
    """Non-streaming ``POST /v1/chat/completions`` behaviour and input validation."""

    @staticmethod
    async def _raw_post_status(port, content_length, payload=b""):
        """Send a hand-built POST to the chat route and return its status code.

        The request bypasses ``json.dumps`` so malformed input can be sent:
        *content_length* is written verbatim into the ``Content-Length``
        header and *payload* is appended as raw bytes.

        The code is parsed from the status line rather than searched for in
        the whole response, so a "400" appearing in a header value or in the
        body cannot satisfy a test by accident.
        """
        request = (
            f"POST /v1/chat/completions HTTP/1.1\r\n"
            f"Host: 127.0.0.1:{port}\r\n"
            f"Content-Type: application/json\r\n"
            f"Content-Length: {content_length}\r\n"
            f"\r\n"
        ).encode() + payload
        reader, writer = await asyncio.open_connection("127.0.0.1", port)
        try:
            writer.write(request)
            await writer.drain()
            response = await asyncio.wait_for(
                reader.read(65536), timeout=10.0,
            )
        finally:
            writer.close()
            await writer.wait_closed()
        status_line = response.split(b"\r\n", 1)[0]
        fields = status_line.split(b" ", 2)
        assert len(fields) >= 2, f"malformed status line: {status_line!r}"
        return int(fields[1])

    @pytest.mark.asyncio
    async def test_no_tools_text_response(self, server_factory):
        """A text reply from the client is returned as the message content."""
        _, port = await server_factory(TextResponse(content="Hello!"))
        body = {"messages": [{"role": "user", "content": "hi"}]}
        status, response_body = await _http_request(
            port, "POST", "/v1/chat/completions", body,
        )
        assert status == 200
        message = json.loads(response_body)["choices"][0]["message"]
        assert message["content"] == "Hello!"

    @pytest.mark.asyncio
    async def test_tool_call_response(self, server_factory):
        """A single tool call from the client appears under ``tool_calls``."""
        _, port = await server_factory(
            [ToolCall(tool="search", args={"q": "test"})],
        )
        body = {
            "messages": [{"role": "user", "content": "search for test"}],
            "tools": [{
                "type": "function",
                "function": {
                    "name": "search",
                    "description": "Search",
                    "parameters": {"type": "object", "properties": {}},
                },
            }],
        }
        status, response_body = await _http_request(
            port, "POST", "/v1/chat/completions", body,
        )
        assert status == 200
        tool_calls = json.loads(response_body)["choices"][0]["message"]["tool_calls"]
        assert len(tool_calls) == 1
        assert tool_calls[0]["function"]["name"] == "search"

    @pytest.mark.asyncio
    async def test_invalid_json_returns_400(self, server_factory):
        """A body that is not valid JSON is rejected with status 400."""
        _, port = await server_factory(TextResponse(content=""))
        bad_body = b"not json"
        status = await self._raw_post_status(port, len(bad_body), bad_body)
        assert status == 400

    @pytest.mark.asyncio
    async def test_invalid_content_length_returns_400(self, server_factory):
        """A non-numeric ``Content-Length`` is rejected with status 400."""
        _, port = await server_factory(TextResponse(content=""))
        status = await self._raw_post_status(port, "abc")
        assert status == 400

    @pytest.mark.asyncio
    async def test_non_object_body_returns_400(self, server_factory):
        """Valid JSON that is not an object (here a list) is rejected with 400."""
        _, port = await server_factory(TextResponse(content=""))
        status, _ = await _http_request(
            port, "POST", "/v1/chat/completions", body=[1, 2, 3]
        )
        assert status == 400
class TestSSEStreaming:
    """Checks for ``"stream": True`` requests answered as ``data:`` events."""

    @pytest.mark.asyncio
    async def test_streaming_text_response(self, server_factory):
        """A streamed text reply yields at least one JSON event and ``[DONE]``."""
        _, port = await server_factory(TextResponse(content="Hello!"))
        request = {
            "messages": [{"role": "user", "content": "hi"}],
            "stream": True,
        }
        payloads = await _sse_request(port, request)
        assert "[DONE]" in payloads
        # Everything except the sentinel must parse as JSON.
        events = [json.loads(item) for item in payloads if item != "[DONE]"]
        assert len(events) > 0

    @pytest.mark.asyncio
    async def test_streaming_tool_call(self, server_factory):
        """A streamed tool call shows up as ``tool_calls`` in some event delta."""
        _, port = await server_factory(
            [ToolCall(tool="search", args={"q": "x"})],
        )
        search_tool = {
            "type": "function",
            "function": {
                "name": "search",
                "description": "Search",
                "parameters": {"type": "object", "properties": {}},
            },
        }
        request = {
            "messages": [{"role": "user", "content": "go"}],
            "tools": [search_tool],
            "stream": True,
        }
        payloads = await _sse_request(port, request)
        assert "[DONE]" in payloads
        events = [json.loads(item) for item in payloads if item != "[DONE]"]
        # Events without a "delta" key count as having an empty delta.
        deltas = (event["choices"][0].get("delta", {}) for event in events)
        assert any("tool_calls" in delta for delta in deltas)
class TestSerialization:
    """Checks a server created with ``serialize_requests=True``."""

    @pytest.mark.asyncio
    async def test_serialized_requests_processed(self, server_factory):
        """A single chat request still completes when serialization is on."""
        _, port = await server_factory(
            TextResponse(content="ok"), serialize=True,
        )
        request = {"messages": [{"role": "user", "content": "hi"}]}
        status, raw = await _http_request(
            port, "POST", "/v1/chat/completions", request,
        )
        assert status == 200
        message = json.loads(raw)["choices"][0]["message"]
        assert message["content"] == "ok"