use super::super::chat_completion::ChatCompletionRequest;
use super::super::options::LLMEventHandlers;
use super::super::types::{Endpoint, LLMChunkResponse, Method};
use crate::Error;
use futures::stream::{Stream, StreamExt};
use serde::Serialize;
use std::time::Duration;
impl super::PlainLLM {
pub(super) async fn http_call(
&self,
endpoint: Endpoint,
method: Method,
request_body: Option<impl Serialize>,
) -> Result<String, Error> {
let uri = format!("{}/{}", self.api_url, endpoint.http_uri());
tracing::info!("HTTP {:?} {}", method, uri);
if let Ok(body) = serde_json::to_string(&request_body) {
tracing::debug!("request body: {}", body);
}
let req = match method {
Method::Post => self.http_client.post(&uri).json(&request_body.unwrap()),
Method::Get => self.http_client.get(&uri),
}
.header("Authorization", format!("Bearer {}", self.token))
.timeout(Duration::from_secs(300));
let response = req.send().await.map_err(Error::Http)?;
let status = response.status();
let text = response.text().await.map_err(Error::Http)?;
tracing::info!("status {}", status.as_u16());
tracing::debug!("response body: {}", text);
if status.is_success() {
Ok(text)
} else {
Err(Error::HttpStatus(status, text))
}
}
pub(super) async fn http_call_streamed(
&self,
endpoint: Endpoint,
method: Method,
request_body: Option<impl Serialize>,
) -> Result<impl Stream<Item = Result<String, reqwest::Error>>, Error> {
let uri = format!("{}/{}", self.api_url, endpoint.http_uri());
tracing::info!("HTTP streaming {:?} {}", method, uri);
if let Ok(body) = serde_json::to_string(&request_body) {
tracing::debug!("request body: {}", body);
}
let req = match method {
Method::Post => self.http_client.post(&uri).json(&request_body.unwrap()),
Method::Get => self.http_client.get(&uri),
}
.header("Authorization", format!("Bearer {}", self.token))
.timeout(Duration::from_secs(300));
let response = req.send().await.map_err(Error::Http)?;
let status = response.status();
tracing::info!("status {}", status.as_u16());
if status.is_success() {
let stream = response.bytes_stream().map(|result| {
result.map(|chunk| {
let chunk_str = String::from_utf8_lossy(&chunk).to_string();
chunk_str.replace("data: ", "")
})
});
Ok(stream)
} else {
let text = response.text().await.map_err(Error::Http)?;
tracing::debug!("response body: {}", text);
Err(Error::HttpStatus(status, text))
}
}
/// Runs a streaming chat completion and forwards incremental output to `handlers`.
///
/// The response text is split into events on blank lines (`"\n\n"`; carriage returns are
/// dropped first so CRLF-delimited streams work too). Each event is parsed as an
/// [`LLMChunkResponse`]; events that fail to parse are logged and skipped, and empty
/// events or comment events (starting with `:`) are ignored silently. Only the first
/// choice of each chunk is inspected:
///
/// * `delta.reasoning_content` is forwarded to `on_thinking`, bracketed by
///   `on_start_thinking` / `on_stop_thinking`.
/// * `delta.content` is appended to the returned content string and scanned for
///   `<think>` … `</think>`: text inside goes to `on_thinking`, text outside to
///   `on_token`, and the tags themselves trigger `on_start_thinking` /
///   `on_stop_thinking`. Only a trailing fragment that could still grow into a tag is
///   withheld; everything else is delivered immediately.
///
/// Processing stops at a `[DONE]` event, at a `finish_reason` of `"tool_calls"`, `"stop"`
/// or `"length"`, or when the stream ends. Any withheld text is flushed at the end, and
/// `on_stop_thinking` is called if the stream ended while still in thinking mode.
///
/// Returns every parsed chunk together with the concatenated `delta.content` (including
/// any think tags, unmodified).
///
/// # Errors
///
/// Propagates errors from `http_call_streamed` and transport errors raised while
/// reading the stream.
pub(super) async fn stream_llm(
    &self,
    request: &ChatCompletionRequest,
    handlers: &LLMEventHandlers,
) -> Result<(Vec<LLMChunkResponse>, String), Error> {
    const THINK_OPEN: &str = "<think>";
    const THINK_CLOSE: &str = "</think>";

    /// Byte length of the longest proper prefix of `tag` that `text` ends with.
    /// Those bytes may be the beginning of `tag` and must not be emitted yet.
    fn held_back_len(text: &str, tag: &str) -> usize {
        (1..tag.len())
            .rev()
            .find(|&len| text.ends_with(&tag[..len]))
            .unwrap_or(0)
    }

    /// Emits as much of `buffer` as possible, toggling `in_think` at think tags.
    fn process_buffer(buffer: &mut String, in_think: &mut bool, handlers: &LLMEventHandlers) {
        loop {
            if *in_think {
                match buffer.find(THINK_CLOSE) {
                    Some(end) => {
                        if end > 0 {
                            if let Some(ref cb) = handlers.on_thinking {
                                cb(&buffer[..end]);
                            }
                        }
                        if let Some(ref cb) = handlers.on_stop_thinking {
                            cb();
                        }
                        buffer.drain(..end + THINK_CLOSE.len());
                        *in_think = false;
                    }
                    None => {
                        let flush_to = buffer.len() - held_back_len(buffer.as_str(), THINK_CLOSE);
                        if flush_to > 0 {
                            if let Some(ref cb) = handlers.on_thinking {
                                cb(&buffer[..flush_to]);
                            }
                            buffer.drain(..flush_to);
                        }
                        break;
                    }
                }
            } else {
                match buffer.find(THINK_OPEN) {
                    Some(start) => {
                        if start > 0 {
                            if let Some(ref cb) = handlers.on_token {
                                cb(&buffer[..start]);
                            }
                        }
                        if let Some(ref cb) = handlers.on_start_thinking {
                            cb();
                        }
                        buffer.drain(..start + THINK_OPEN.len());
                        *in_think = true;
                    }
                    None => {
                        // Withhold only a possible tag fragment; a lone `<` in ordinary
                        // text must not stall delivery of everything that follows it.
                        let flush_to = buffer.len() - held_back_len(buffer.as_str(), THINK_OPEN);
                        if flush_to > 0 {
                            if let Some(ref cb) = handlers.on_token {
                                cb(&buffer[..flush_to]);
                            }
                            buffer.drain(..flush_to);
                        }
                        break;
                    }
                }
            }
        }
    }

    tracing::info!("stream_llm start");
    let stream = self
        .http_call_streamed(Endpoint::ChatCompletion, Method::Post, Some(request))
        .await?;
    futures::pin_mut!(stream);

    let mut raw_chunks = Vec::new();
    let mut partial_content = String::new();
    let mut buffer = String::new();
    let mut sse_buffer = String::new();
    let mut in_think = false;
    let mut in_reasoning = false;
    let mut stream_ended = false;

    'outer: while !stream_ended {
        match stream.next().await {
            Some(chunk_result) => {
                let chunk_str = chunk_result?;
                // Whitespace-only chunks must be kept: they can carry the blank line
                // that terminates the previous event.
                if chunk_str.is_empty() {
                    continue;
                }
                tracing::trace!("raw chunk: {}", chunk_str);
                // JSON payloads cannot contain raw carriage returns inside strings, so
                // dropping them only normalises CRLF line endings.
                sse_buffer.extend(chunk_str.chars().filter(|&c| c != '\r'));
            }
            None => {
                // Terminate a trailing event that lacks its final blank line so it is
                // still processed below.
                stream_ended = true;
                sse_buffer.push_str("\n\n");
            }
        }
        while let Some(idx) = sse_buffer.find("\n\n") {
            let raw_event: String = sse_buffer.drain(..idx + 2).collect();
            let event = raw_event.trim().trim_start_matches("data: ");
            if event.is_empty() || event.starts_with(':') {
                continue;
            }
            if event == "[DONE]" {
                break 'outer;
            }
            match serde_json::from_str::<LLMChunkResponse>(event) {
                Ok(chunk) => {
                    let first_choice = chunk.choices.first();
                    let reasoning_text =
                        first_choice.and_then(|cc| cc.delta.reasoning_content.clone());
                    let content_text = first_choice.and_then(|cc| cc.delta.content.clone());
                    let finish_reason = first_choice.and_then(|cc| cc.finish_reason.clone());
                    raw_chunks.push(chunk);

                    if let Some(text) = reasoning_text {
                        if !in_reasoning {
                            if let Some(ref cb) = handlers.on_start_thinking {
                                cb();
                            }
                            in_reasoning = true;
                            in_think = true;
                        }
                        if let Some(ref cb) = handlers.on_thinking {
                            if !text.is_empty() {
                                cb(&text);
                            }
                        }
                    }
                    if let Some(token_text) = content_text {
                        // The first regular content after reasoning ends the thinking phase.
                        if in_reasoning {
                            if let Some(ref cb) = handlers.on_stop_thinking {
                                cb();
                            }
                            in_reasoning = false;
                            in_think = false;
                        }
                        partial_content.push_str(&token_text);
                        buffer.push_str(&token_text);
                        process_buffer(&mut buffer, &mut in_think, handlers);
                    }
                    if let Some(reason) = finish_reason {
                        if reason == "tool_calls" || reason == "stop" || reason == "length" {
                            break 'outer;
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to parse chunk as JSON: {} -- raw: {}", e, event);
                }
            }
        }
    }

    // Deliver whatever was withheld as a possible tag fragment.
    if !buffer.is_empty() {
        if in_think {
            if let Some(ref cb) = handlers.on_thinking {
                cb(&buffer);
            }
        } else if let Some(ref cb) = handlers.on_token {
            cb(&buffer);
        }
    }
    // Balance `on_start_thinking` even when the stream stops inside an unterminated
    // `<think>` block, not only for reasoning deltas.
    if in_think || in_reasoning {
        if let Some(ref cb) = handlers.on_stop_thinking {
            cb();
        }
    }
    tracing::info!("stream complete");
    tracing::trace!("stream complete; {} chunks collected", raw_chunks.len());
    Ok((raw_chunks, partial_content))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::types::Message;
use futures::StreamExt;
use serde_json::json;
/// Spawns a one-shot HTTP server on an ephemeral loopback port for tests.
///
/// The server accepts exactly one connection, reads the complete request (head plus
/// the body announced by `Content-Length`, if any), answers with the given `status`,
/// optional `Content-Type` header and `body`, then closes the connection.
///
/// Returns the bound address and the handle of the serving thread; join the handle to
/// surface panics from the server side.
fn start_server(
    body: &'static [u8],
    status: u16,
    content_type: Option<&str>,
) -> (std::net::SocketAddr, std::thread::JoinHandle<()>) {
    use std::io::{Read, Write};

    let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
    let addr = listener.local_addr().unwrap();
    // The requested code is always sent as-is; only the reason phrase falls back to a
    // placeholder for codes not listed here.
    let reason = match status {
        200 => "OK",
        400 => "Bad Request",
        401 => "Unauthorized",
        404 => "Not Found",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        503 => "Service Unavailable",
        _ => "Unknown",
    };
    let ct_header = content_type
        .map(|ct| format!("Content-Type: {}\r\n", ct))
        .unwrap_or_default();
    let head = format!(
        "HTTP/1.1 {} {}\r\nContent-Length: {}\r\n{}\r\n",
        status,
        reason,
        body.len(),
        ct_header
    );
    let handle = std::thread::spawn(move || {
        let (mut stream, _) = listener.accept().unwrap();
        // Safety valve so a malformed request cannot hang the test forever.
        let _ = stream.set_read_timeout(Some(std::time::Duration::from_secs(5)));

        // Drain the whole request before answering: closing a socket that still has
        // unread data can reset the connection and lose the response.
        let mut request = Vec::new();
        let mut chunk = [0u8; 1024];
        loop {
            let read = match stream.read(&mut chunk) {
                Ok(0) | Err(_) => break,
                Ok(n) => n,
            };
            request.extend_from_slice(&chunk[..read]);
            if let Some(head_end) = request.windows(4).position(|w| w == b"\r\n\r\n") {
                let request_head = String::from_utf8_lossy(&request[..head_end]);
                let declared_len = request_head
                    .lines()
                    .filter_map(|line| line.split_once(':'))
                    .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
                    .and_then(|(_, value)| value.trim().parse::<usize>().ok())
                    .unwrap_or(0);
                if request.len() >= head_end + 4 + declared_len {
                    break;
                }
            }
        }

        stream.write_all(head.as_bytes()).unwrap();
        stream.write_all(body).unwrap();
    });
    (addr, handle)
}
/// A 200 response makes `http_call` return the response body unchanged.
#[tokio::test]
async fn http_call_success() {
    let (addr, server) = start_server(b"ok", 200, None);
    let base_url = format!("http://{}", addr);
    let client = super::super::PlainLLM::new(&base_url, "t");
    let payload = json!({});

    let call = client.http_call(Endpoint::ChatCompletion, Method::Post, Some(&payload));
    let body = call.await.unwrap();

    // Joining surfaces any panic raised on the server thread.
    server.join().unwrap();
    assert_eq!(body, "ok");
}
#[tokio::test]
async fn http_call_error() {
let (addr, handle) = start_server(b"fail", 500, None);
let llm = super::super::PlainLLM::new(&format!("http://{}", addr), "t");
let empty = json!({});
let err = llm
.http_call(Endpoint::ChatCompletion, Method::Post, Some(&empty))
.await
.unwrap_err();
handle.join().unwrap();
match err {
crate::Error::HttpStatus(code, body) => {
assert_eq!(code.as_u16(), 500);
assert_eq!(body, "fail");
}
_ => panic!("unexpected error"),
}
}
/// The streamed body is yielded as text with the `data: ` prefixes removed.
#[tokio::test]
async fn http_call_streamed_returns_chunks() {
    const BODY: &str = "data: {\"a\":1}\n\ndata: [DONE]\n\n";
    let (addr, server) = start_server(BODY.as_bytes(), 200, Some("text/event-stream"));
    let base_url = format!("http://{}", addr);
    let client = super::super::PlainLLM::new(&base_url, "t");
    let payload = json!({});

    let chunks = client
        .http_call_streamed(Endpoint::ChatCompletion, Method::Post, Some(&payload))
        .await
        .unwrap();
    // Drain the stream, failing the test on the first transport error.
    let collected: Vec<String> = chunks.map(|item| item.unwrap()).collect().await;

    server.join().unwrap();
    assert_eq!(collected, vec!["{\"a\":1}\n\n[DONE]\n\n"]);
}
/// A single content chunk followed by `[DONE]` yields one parsed chunk and its text.
#[tokio::test]
async fn stream_llm_parses_sse() {
    const BODY: &str = "data: {\"id\":\"1\",\"object\":\"chunk\",\"created\":0,\"model\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"hi\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n";
    let (addr, server) = start_server(BODY.as_bytes(), 200, Some("text/event-stream"));
    let base_url = format!("http://{}", addr);
    let client = super::super::PlainLLM::new(&base_url, "t");

    let messages = vec![Message::new("user", "hi")];
    let mut request = ChatCompletionRequest::new("m".into(), messages);
    request.stream = true;
    // Default handlers: only the returned values are checked here.
    let handlers = LLMEventHandlers::default();

    let (chunks, content) = client.stream_llm(&request, &handlers).await.unwrap();

    server.join().unwrap();
    assert_eq!(chunks.len(), 1);
    assert_eq!(content, "hi");
}
}