use std::collections::VecDeque;
use futures::stream::{BoxStream, StreamExt};
use serde::{Deserialize, Serialize};
use crate::council::agent::{Agent, AgentError, ChatRequest, ChatToken, ChatTokenStream};
use crate::council::event::ExpertId;
/// Chat agent backed by an OpenAI-compatible HTTP server, streaming completions from
/// `POST {endpoint}/v1/chat/completions`.
///
/// `Debug` is implemented by hand (instead of derived) so the bearer token held in
/// `api_key` is never written to logs or panic messages.
pub struct OpenAiHttpAgent {
    /// Expert id reported by [`Agent::id`] and attached to every error this agent produces.
    id: ExpertId,
    /// Model name reported by [`Agent::model`].
    model: String,
    /// Value reported by [`Agent::timeout_ms`].
    /// NOTE(review): not applied to the HTTP client in this file; presumably the caller
    /// enforces it — confirm.
    timeout_ms: u64,
    /// Base URL; `/v1/chat/completions` is appended (a trailing `/` is tolerated).
    endpoint: String,
    /// Optional secret sent as `Authorization: Bearer <key>`.
    api_key: Option<String>,
    /// `reqwest` client, cloned into each request future.
    client: reqwest::Client,
}

impl std::fmt::Debug for OpenAiHttpAgent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiHttpAgent")
            .field("id", &self.id)
            .field("model", &self.model)
            .field("timeout_ms", &self.timeout_ms)
            .field("endpoint", &self.endpoint)
            // Reveal only whether a key is configured, never the key itself.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("client", &self.client)
            .finish()
    }
}
impl OpenAiHttpAgent {
    /// Creates an agent for the OpenAI-compatible server at `endpoint`.
    ///
    /// `endpoint` is the base URL (the `/v1/chat/completions` path is added per request).
    /// When `api_key` is `Some`, it is sent as a bearer token.
    ///
    /// NOTE(review): `timeout_ms` is only stored and reported via [`Agent::timeout_ms`]; the
    /// `reqwest` client built here has no timeout configured.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Transport`] when the `reqwest` client cannot be built.
    pub fn new(
        id: impl Into<ExpertId>,
        model: impl Into<String>,
        endpoint: impl Into<String>,
        timeout_ms: u64,
        api_key: Option<String>,
    ) -> Result<Self, AgentError> {
        let id: ExpertId = id.into();
        let client = match reqwest::Client::builder().build() {
            Ok(client) => client,
            Err(err) => {
                return Err(AgentError::Transport {
                    agent_id: id,
                    message: format!("client build failed: {err}"),
                });
            }
        };
        Ok(Self {
            id,
            model: model.into(),
            timeout_ms,
            endpoint: endpoint.into(),
            api_key,
            client,
        })
    }
}
impl Agent for OpenAiHttpAgent {
    fn id(&self) -> &ExpertId {
        &self.id
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }

    /// Streams a chat completion for `request`.
    ///
    /// Nothing is sent until the returned stream is first polled. The stream yields:
    /// - a single [`AgentError::Transport`] if sending fails or the server answers with a
    ///   non-2xx status (the message includes the status and the response body);
    /// - otherwise, the tokens decoded from the SSE body by `translate_sse`.
    fn chat(&self, request: ChatRequest) -> ChatTokenStream {
        let body = build_body(&request);
        let base = self.endpoint.trim_end_matches('/');
        let url = format!("{base}/v1/chat/completions");
        // Owned copies so the request future is `'static` and does not borrow `self`.
        let (agent_id, client, api_key) =
            (self.id.clone(), self.client.clone(), self.api_key.clone());

        let connect = async move {
            let builder = client.post(&url).json(&body);
            let builder = match api_key {
                Some(key) => builder.bearer_auth(key),
                None => builder,
            };

            let response = match builder.send().await {
                Ok(response) => response,
                Err(err) => {
                    return single_transport_err(agent_id, format!("request failed: {err}"));
                }
            };

            let status = response.status();
            if !status.is_success() {
                let text = response.text().await.unwrap_or_default();
                return single_transport_err(agent_id, format!("http {status}: {text}"));
            }

            translate_sse(agent_id, response.bytes_stream().boxed())
        };

        // One-shot future resolving to a stream, flattened into the stream itself.
        Box::pin(futures::stream::once(connect).flatten())
    }
}
/// JSON body POSTed to `/v1/chat/completions`.
///
/// Sampling knobs that are `None` are left out of the JSON entirely
/// (`skip_serializing_if`) rather than being sent as `null`.
#[derive(Serialize)]
struct ChatCompletionRequestBody {
model: String,
messages: Vec<WireMessage>,
/// Set to `true` by `build_body`; the response is consumed as server-sent events.
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
seed: Option<u64>,
}
/// One chat message as serialized on the wire: `{"role": ..., "content": ...}`.
#[derive(Serialize)]
struct WireMessage {
/// Wire role string obtained from `ChatRole::as_wire_str`.
role: &'static str,
content: String,
}
/// Converts a [`ChatRequest`] into the streaming wire body.
///
/// `stream` is always `true`. Each sampling knob is copied through unchanged, so knobs
/// the caller left unset stay `None`.
fn build_body(req: &ChatRequest) -> ChatCompletionRequestBody {
    let sampling = &req.sampling;

    let mut messages = Vec::with_capacity(req.messages.len());
    for message in &req.messages {
        messages.push(WireMessage {
            role: message.role.as_wire_str(),
            content: message.content.clone(),
        });
    }

    ChatCompletionRequestBody {
        model: req.model.clone(),
        messages,
        stream: true,
        temperature: sampling.temperature,
        top_p: sampling.top_p,
        max_tokens: sampling.max_tokens,
        seed: sampling.seed,
    }
}
/// One streamed `data:` payload of a chat completion.
///
/// Only the fields this module reads are modelled. Other fields are ignored, because no
/// `deny_unknown_fields` attribute is set. A missing `choices` array deserializes as empty.
#[derive(Deserialize)]
struct ChatCompletionChunk {
#[serde(default)]
choices: Vec<ChunkChoice>,
}
/// One entry of `choices`. `finish_reason` is `None` until the server marks the choice as
/// finished.
#[derive(Deserialize)]
struct ChunkChoice {
#[serde(default)]
delta: ChunkDelta,
#[serde(default)]
finish_reason: Option<String>,
}
/// Incremental content of one choice. `content` is `None` when the field is absent or null.
#[derive(Deserialize, Default)]
struct ChunkDelta {
#[serde(default)]
content: Option<String>,
}
struct SseState {
inner: BoxStream<'static, reqwest::Result<bytes::Bytes>>,
buf: String,
pending: VecDeque<Result<ChatToken, AgentError>>,
done: bool,
agent_id: ExpertId,
}
fn translate_sse(
agent_id: ExpertId,
bytes_stream: BoxStream<'static, reqwest::Result<bytes::Bytes>>,
) -> ChatTokenStream {
let state = SseState {
inner: bytes_stream,
buf: String::new(),
pending: VecDeque::new(),
done: false,
agent_id,
};
let stream = futures::stream::unfold(state, |mut state| async move {
if let Some(item) = state.pending.pop_front() {
return Some((item, state));
}
if state.done {
return None;
}
loop {
match state.inner.next().await {
Some(Ok(bytes)) => {
let s = match std::str::from_utf8(&bytes) {
Ok(s) => s.to_string(),
Err(e) => {
state.done = true;
return Some((
Err(AgentError::Stream {
agent_id: state.agent_id.clone(),
message: format!("non-utf8 sse frame: {e}"),
}),
state,
));
}
};
state.buf.push_str(&s);
while let Some(end) = state.buf.find("\n\n") {
let frame: String = state.buf.drain(..end + 2).collect();
match parse_frame(&frame) {
FrameOutcome::Done => {
state.done = true;
return state.pending.pop_front().map(|i| (i, state));
}
FrameOutcome::Token(text, finish_reason) => {
let finished = finish_reason.is_some();
if !text.is_empty() || finished {
state.pending.push_back(Ok(ChatToken {
text,
finished,
finish_reason,
}));
}
}
FrameOutcome::Skip => {}
FrameOutcome::ParseError(msg) => {
state.done = true;
state.pending.push_back(Err(AgentError::Stream {
agent_id: state.agent_id.clone(),
message: msg,
}));
}
}
}
if let Some(item) = state.pending.pop_front() {
return Some((item, state));
}
}
Some(Err(e)) => {
state.done = true;
return Some((
Err(AgentError::Stream {
agent_id: state.agent_id.clone(),
message: format!("sse read error: {e}"),
}),
state,
));
}
None => {
return None;
}
}
}
});
Box::pin(stream)
}
/// What `parse_frame` made of one SSE event.
enum FrameOutcome {
/// The payload was the `[DONE]` sentinel.
Done,
/// Text delta of the first choice (empty string if absent) and its `finish_reason`, if any.
Token(String, Option<String>),
/// The event carried no `data:` payload (for example, only a `:` comment line).
Skip,
/// The `data:` payload was not valid chunk JSON; the string describes the failure.
ParseError(String),
}
/// Interprets one SSE event (the text up to and including its blank-line terminator).
///
/// The values of all `data:` lines are concatenated, each followed by `\n`, and the result
/// is trimmed. Whitespace before `data:` and after the colon is ignored. Lines without
/// `data:` (such as `:` comments) contribute nothing. The outcome is then:
/// - an empty payload → [`FrameOutcome::Skip`];
/// - the literal `[DONE]` → [`FrameOutcome::Done`];
/// - otherwise the payload is parsed as a [`ChatCompletionChunk`], and only the first choice
///   is used.
fn parse_frame(frame: &str) -> FrameOutcome {
    let joined: String = frame
        .lines()
        .filter_map(|line| line.trim_start().strip_prefix("data:"))
        .flat_map(|value| [value.trim_start(), "\n"])
        .collect();

    let data = joined.trim();
    if data.is_empty() {
        return FrameOutcome::Skip;
    }
    if data == "[DONE]" {
        return FrameOutcome::Done;
    }

    let chunk = match serde_json::from_str::<ChatCompletionChunk>(data) {
        Ok(chunk) => chunk,
        Err(e) => {
            return FrameOutcome::ParseError(format!("bad sse json: {e}; body was `{data}`"));
        }
    };

    match chunk.choices.into_iter().next() {
        Some(choice) => {
            FrameOutcome::Token(choice.delta.content.unwrap_or_default(), choice.finish_reason)
        }
        None => FrameOutcome::Token(String::new(), None),
    }
}
fn single_transport_err(agent_id: ExpertId, message: String) -> ChatTokenStream {
Box::pin(futures::stream::iter(vec![Err(AgentError::Transport {
agent_id,
message,
})]))
}
// Tests: `translate_sse` and `build_body` are exercised directly. The `http_agent_*`
// tests run the real agent against an in-process axum server on 127.0.0.1.
#[cfg(test)]
mod tests {
use super::*;
use crate::council::agent::{ChatMessage, ChatRole};
/// Minimal request: model "m", one user message "hi", default sampling, no request id.
fn req() -> ChatRequest {
ChatRequest {
model: "m".into(),
messages: vec![ChatMessage {
role: ChatRole::User,
content: "hi".into(),
}],
sampling: Default::default(),
request_id: None,
}
}
/// Drains `stream` to completion, mapping each error to its `to_string()` text.
async fn collect(stream: ChatTokenStream) -> Vec<Result<ChatToken, String>> {
let mut s = stream;
let mut out = Vec::new();
while let Some(item) = s.next().await {
out.push(item.map_err(|e| e.to_string()));
}
out
}
/// Encodes each frame followed by `\n\n` and delivers the whole body as a single chunk.
fn sse_bytes(frames: &[&str]) -> BoxStream<'static, reqwest::Result<bytes::Bytes>> {
let payload: Vec<u8> = frames
.iter()
.flat_map(|f| format!("{f}\n\n").into_bytes())
.collect();
let chunks: Vec<Result<bytes::Bytes, reqwest::Error>> =
vec![Ok(bytes::Bytes::from(payload))];
futures::stream::iter(chunks).boxed()
}
/// Same body as `sse_bytes`, but cut into `chunk_size`-byte chunks so that frames straddle
/// chunk boundaries.
fn sse_bytes_split(frames: &[&str], chunk_size: usize) -> BoxStream<'static, reqwest::Result<bytes::Bytes>> {
let payload: Vec<u8> = frames
.iter()
.flat_map(|f| format!("{f}\n\n").into_bytes())
.collect();
let chunks: Vec<Result<bytes::Bytes, reqwest::Error>> = payload
.chunks(chunk_size)
.map(|c| Ok(bytes::Bytes::copy_from_slice(c)))
.collect();
futures::stream::iter(chunks).boxed()
}
/// Tokens come out in server order, and the last one is marked `finished`.
#[tokio::test]
async fn parses_openai_sse_tokens_in_order() {
let frames = [
r#"data: {"choices":[{"delta":{"content":"hello"}}]}"#,
r#"data: {"choices":[{"delta":{"content":" "}}]}"#,
r#"data: {"choices":[{"delta":{"content":"world"},"finish_reason":"stop"}]}"#,
r#"data: [DONE]"#,
];
let stream = translate_sse("A".into(), sse_bytes(&frames));
let items = collect(stream).await;
let texts: Vec<_> = items
.iter()
.filter_map(|r| r.as_ref().ok().map(|t| t.text.clone()))
.collect();
assert_eq!(texts, vec!["hello", " ", "world"]);
assert!(items.last().unwrap().as_ref().unwrap().finished);
}
/// Frames delivered in 4-byte pieces are reassembled before parsing.
#[tokio::test]
async fn parses_frames_split_across_byte_chunks() {
let frames = [
r#"data: {"choices":[{"delta":{"content":"hello"}}]}"#,
r#"data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}]}"#,
r#"data: [DONE]"#,
];
let stream = translate_sse("A".into(), sse_bytes_split(&frames, 4));
let items = collect(stream).await;
let texts: Vec<_> = items
.iter()
.filter_map(|r| r.as_ref().ok().map(|t| t.text.clone()))
.collect();
assert_eq!(texts, vec!["hello", " world"]);
}
/// An SSE comment line (`:`-prefixed) produces no token; only the data frame does.
#[tokio::test]
async fn ignores_empty_and_comment_frames() {
let frames = [
r#": this is a comment"#,
r#"data: {"choices":[{"delta":{"content":"x"}}]}"#,
r#"data: [DONE]"#,
];
let stream = translate_sse("A".into(), sse_bytes(&frames));
let items = collect(stream).await;
let texts: Vec<_> = items
.iter()
.filter_map(|r| r.as_ref().ok().map(|t| t.text.clone()))
.collect();
assert_eq!(texts, vec!["x"]);
}
/// A non-JSON `data:` payload surfaces as the first (error) item of the stream.
#[tokio::test]
async fn parse_error_surfaces_as_stream_error() {
let frames = [r#"data: not-json"#, r#"data: [DONE]"#];
let stream = translate_sse("A".into(), sse_bytes(&frames));
let items = collect(stream).await;
assert!(matches!(items[0], Err(_)), "got {items:?}");
}
/// Knobs that are set get serialized, unset ones (`top_p`) are omitted, and `stream` is true.
#[test]
fn body_serializes_with_stream_true_and_sampling_knobs() {
let mut r = req();
r.sampling.temperature = Some(0.7);
r.sampling.max_tokens = Some(128);
let body = build_body(&r);
let json = serde_json::to_string(&body).unwrap();
assert!(json.contains(r#""stream":true"#));
assert!(json.contains(r#""temperature":0.7"#));
assert!(json.contains(r#""max_tokens":128"#));
assert!(!json.contains("top_p"), "omitted knobs must not serialize");
}
// Imports used only by the HTTP integration tests below.
use axum::response::sse::{Event, Sse};
use axum::routing::post;
use axum::Router;
use std::convert::Infallible;
use std::time::Duration;
use tokio::sync::oneshot;
/// Mock handler: two content deltas (the second with `finish_reason: "stop"`), then `[DONE]`.
async fn fixed_stream_handler() -> Sse<impl futures::Stream<Item = Result<Event, Infallible>>>
{
let events = vec![
Ok(Event::default().data(r#"{"choices":[{"delta":{"content":"alpha"}}]}"#)),
Ok(Event::default().data(r#"{"choices":[{"delta":{"content":"-beta"},"finish_reason":"stop"}]}"#)),
Ok(Event::default().data("[DONE]")),
];
Sse::new(futures::stream::iter(events))
}
/// Serves `fixed_stream_handler` on an ephemeral localhost port. Returns the bound address
/// and a sender that triggers graceful shutdown.
async fn spawn_mock_openai() -> (std::net::SocketAddr, oneshot::Sender<()>) {
let app = Router::new().route("/v1/chat/completions", post(fixed_stream_handler));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
tokio::spawn(async move {
let _ = axum::serve(listener, app)
.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
})
.await;
});
// NOTE(review): the listener is already bound at this point, so this sleep looks
// unnecessary. Connections should wait in the backlog until `serve` accepts them.
tokio::time::sleep(Duration::from_millis(30)).await;
(addr, shutdown_tx)
}
/// End-to-end: tokens streamed by the mock server arrive in order, and the last is `finished`.
#[tokio::test]
async fn http_agent_streams_tokens_from_openai_compatible_server() {
let (addr, shutdown) = spawn_mock_openai().await;
let endpoint = format!("http://{addr}");
let agent = OpenAiHttpAgent::new("A", "test-model", endpoint, 5_000, None).unwrap();
let mut stream = agent.chat(req());
let mut texts = Vec::new();
let mut last_finished = false;
while let Some(item) = stream.next().await {
let tok = item.expect("ok");
last_finished = tok.finished;
texts.push(tok.text);
}
assert_eq!(texts, vec!["alpha", "-beta"]);
assert!(last_finished);
let _ = shutdown.send(());
}
/// A 401 response surfaces as an `AgentError::Transport` whose message contains the status.
#[tokio::test]
async fn http_agent_reports_transport_error_for_non_2xx() {
let app = Router::new().route(
"/v1/chat/completions",
post(|| async { (axum::http::StatusCode::UNAUTHORIZED, "nope") }),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
tokio::spawn(async move {
let _ = axum::serve(listener, app)
.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
})
.await;
});
tokio::time::sleep(Duration::from_millis(30)).await;
let endpoint = format!("http://{addr}");
let agent = OpenAiHttpAgent::new("A", "m", endpoint, 5_000, None).unwrap();
let mut stream = agent.chat(req());
let item = stream.next().await.expect("one item");
match item {
Err(AgentError::Transport { agent_id, message }) => {
assert_eq!(agent_id, "A");
assert!(message.contains("401"), "got {message}");
}
other => panic!("expected Transport, got {other:?}"),
}
let _ = shutdown_tx.send(());
}
}