// plainllm 1.2.0
//
// A plain & simple LLM client.
// Documentation
use super::super::chat_completion::ChatCompletionRequest;
use super::super::options::LLMEventHandlers;
use super::super::types::{Endpoint, LLMChunkResponse, Method};
use crate::Error;
use futures::stream::{Stream, StreamExt};
use serde::Serialize;
use std::time::Duration;
impl super::PlainLLM {
    //////////////////////////////////////////
    //  Low-level HTTP calls
    //////////////////////////////////////////
    pub(super) async fn http_call(
        &self,
        endpoint: Endpoint,
        method: Method,
        request_body: Option<impl Serialize>,
    ) -> Result<String, Error> {
        let uri = format!("{}/{}", self.api_url, endpoint.http_uri());
        tracing::info!("HTTP {:?} {}", method, uri);
        if let Ok(body) = serde_json::to_string(&request_body) {
            tracing::debug!("request body: {}", body);
        }
        let req = match method {
            Method::Post => self.http_client.post(&uri).json(&request_body.unwrap()),
            Method::Get => self.http_client.get(&uri),
        }
        .header("Authorization", format!("Bearer {}", self.token))
        .timeout(Duration::from_secs(300));

        let response = req.send().await.map_err(Error::Http)?;
        let status = response.status();
        let text = response.text().await.map_err(Error::Http)?;
        tracing::info!("status {}", status.as_u16());
        tracing::debug!("response body: {}", text);
        if status.is_success() {
            Ok(text)
        } else {
            Err(Error::HttpStatus(status, text))
        }
    }

    pub(super) async fn http_call_streamed(
        &self,
        endpoint: Endpoint,
        method: Method,
        request_body: Option<impl Serialize>,
    ) -> Result<impl Stream<Item = Result<String, reqwest::Error>>, Error> {
        let uri = format!("{}/{}", self.api_url, endpoint.http_uri());
        tracing::info!("HTTP streaming {:?} {}", method, uri);
        if let Ok(body) = serde_json::to_string(&request_body) {
            tracing::debug!("request body: {}", body);
        }
        let req = match method {
            Method::Post => self.http_client.post(&uri).json(&request_body.unwrap()),
            Method::Get => self.http_client.get(&uri),
        }
        .header("Authorization", format!("Bearer {}", self.token))
        .timeout(Duration::from_secs(300));

        let response = req.send().await.map_err(Error::Http)?;
        let status = response.status();
        tracing::info!("status {}", status.as_u16());
        if status.is_success() {
            let stream = response.bytes_stream().map(|result| {
                result.map(|chunk| {
                    let chunk_str = String::from_utf8_lossy(&chunk).to_string();
                    chunk_str.replace("data: ", "")
                })
            });
            Ok(stream)
        } else {
            let text = response.text().await.map_err(Error::Http)?;
            tracing::debug!("response body: {}", text);
            Err(Error::HttpStatus(status, text))
        }
    }

    pub(super) async fn stream_llm(
        &self,
        request: &ChatCompletionRequest,
        handlers: &LLMEventHandlers,
    ) -> Result<(Vec<LLMChunkResponse>, String), Error> {
        tracing::info!("stream_llm start");
        let stream = self
            .http_call_streamed(Endpoint::ChatCompletion, Method::Post, Some(request))
            .await?;

        let mut raw_chunks = Vec::new();
        let mut partial_content = String::new();
        let mut buffer = String::new();
        let mut in_think = false;
        let mut in_reasoning = false;

        fn process_buffer(buffer: &mut String, in_think: &mut bool, handlers: &LLMEventHandlers) {
            const THINK_OPEN: &str = "<think>";
            const THINK_CLOSE: &str = "</think>";

            loop {
                if *in_think {
                    if let Some(end) = buffer.find(THINK_CLOSE) {
                        let text = &buffer[..end];
                        if let Some(ref cb) = handlers.on_thinking {
                            if !text.is_empty() {
                                cb(text);
                            }
                        }
                        if let Some(ref cb) = handlers.on_stop_thinking {
                            cb();
                        }
                        buffer.drain(..end + THINK_CLOSE.len());
                        *in_think = false;
                    } else {
                        if buffer.len() > THINK_CLOSE.len() {
                            let flush_chars =
                                buffer.chars().count().saturating_sub(THINK_CLOSE.len());
                            let flush_byte_idx = buffer
                                .char_indices()
                                .nth(flush_chars)
                                .map(|(idx, _)| idx)
                                .unwrap_or(buffer.len());
                            let text = buffer[..flush_byte_idx].to_string();
                            if let Some(ref cb) = handlers.on_thinking {
                                if !text.is_empty() {
                                    cb(&text);
                                }
                            }
                            buffer.drain(..flush_byte_idx);
                        }
                        break;
                    }
                } else if let Some(start) = buffer.find(THINK_OPEN) {
                    let text = &buffer[..start];
                    if let Some(ref cb) = handlers.on_token {
                        if !text.is_empty() {
                            cb(text);
                        }
                    }
                    if let Some(ref cb) = handlers.on_start_thinking {
                        cb();
                    }
                    buffer.drain(..start + THINK_OPEN.len());
                    *in_think = true;
                } else {
                    if let Some(pos) = buffer.rfind('<') {
                        if pos > 0 {
                            let text = &buffer[..pos];
                            if let Some(ref cb) = handlers.on_token {
                                if !text.is_empty() {
                                    cb(text);
                                }
                            }
                            buffer.drain(..pos);
                        }
                        break;
                    } else {
                        if let Some(ref cb) = handlers.on_token {
                            if !buffer.is_empty() {
                                cb(buffer);
                            }
                        }
                        buffer.clear();
                        break;
                    }
                }
            }
        }

        futures::pin_mut!(stream);
        let mut sse_buffer = String::new();
        'outer: while let Some(chunk_result) = stream.next().await {
            let chunk_str = chunk_result?;
            if chunk_str.trim().is_empty() {
                continue;
            }
            tracing::trace!("raw chunk: {}", chunk_str);
            sse_buffer.push_str(&chunk_str);
            while let Some(idx) = sse_buffer.find("\n\n") {
                let mut event = sse_buffer[..idx].to_string();
                sse_buffer.drain(..idx + 2);
                event = event.trim().trim_start_matches("data: ").to_string();
                if event == "[DONE]" {
                    break 'outer;
                }
                match serde_json::from_str::<LLMChunkResponse>(&event) {
                    Ok(chunk) => {
                        raw_chunks.push(chunk.clone());
                        let reasoning_text = chunk
                            .choices
                            .get(0)
                            .and_then(|cc| cc.delta.reasoning_content.clone());
                        let content_text =
                            chunk.choices.get(0).and_then(|cc| cc.delta.content.clone());

                        if let Some(text) = reasoning_text {
                            if !in_reasoning {
                                if let Some(ref cb) = handlers.on_start_thinking {
                                    cb();
                                }
                                in_reasoning = true;
                                in_think = true;
                            }
                            if let Some(ref cb) = handlers.on_thinking {
                                if !text.is_empty() {
                                    cb(&text);
                                }
                            }
                        }

                        if let Some(token_text) = content_text {
                            if in_reasoning {
                                if let Some(ref cb) = handlers.on_stop_thinking {
                                    cb();
                                }
                                in_reasoning = false;
                                in_think = false;
                            }
                            partial_content.push_str(&token_text);
                            buffer.push_str(&token_text);
                            process_buffer(&mut buffer, &mut in_think, handlers);
                        }

                        if let Some(reason) =
                            chunk.choices.get(0).and_then(|cc| cc.finish_reason.clone())
                        {
                            if reason == "tool_calls" || reason == "stop" || reason == "length" {
                                break 'outer;
                            }
                        }
                    }
                    Err(e) => {
                        tracing::warn!("Failed to parse chunk as JSON: {} -- raw: {}", e, event);
                    }
                }
            }
        }
        // Flush any remaining buffered data
        if !buffer.is_empty() {
            if in_think {
                if let Some(ref cb) = handlers.on_thinking {
                    cb(&buffer);
                }
            } else if let Some(ref cb) = handlers.on_token {
                cb(&buffer);
            }
        }
        if in_reasoning {
            if let Some(ref cb) = handlers.on_stop_thinking {
                cb();
            }
        }
        tracing::info!("stream complete");
        tracing::trace!("stream complete; {} chunks collected", raw_chunks.len());
        Ok((raw_chunks, partial_content))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::types::Message;
    use futures::StreamExt;
    use serde_json::json;

    /// Reads one HTTP request (headers plus a `Content-Length` body) from `stream`.
    ///
    /// Consuming the whole request matters: closing a socket that still holds
    /// unread request bytes can reset the connection before the client reads
    /// the response.
    fn read_request(stream: &mut std::net::TcpStream) {
        use std::io::Read;
        let mut received = Vec::new();
        let mut buf = [0u8; 1024];
        loop {
            let n = match stream.read(&mut buf) {
                Ok(0) | Err(_) => return,
                Ok(n) => n,
            };
            received.extend_from_slice(&buf[..n]);
            if let Some(header_end) = received.windows(4).position(|w| w == b"\r\n\r\n") {
                let headers =
                    String::from_utf8_lossy(&received[..header_end]).to_ascii_lowercase();
                let body_len = headers
                    .lines()
                    .find_map(|line| line.strip_prefix("content-length:"))
                    .and_then(|value| value.trim().parse::<usize>().ok())
                    .unwrap_or(0);
                if received.len() >= header_end + 4 + body_len {
                    return;
                }
            }
        }
    }

    /// Spawns a one-shot HTTP server on a random local port.
    ///
    /// The server answers the first request with `status` (using its canonical
    /// reason phrase), `body`, and an optional `Content-Type`, then closes.
    fn start_server(
        body: &'static [u8],
        status: u16,
        content_type: Option<&str>,
    ) -> (std::net::SocketAddr, std::thread::JoinHandle<()>) {
        use std::io::Write;
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let ct_header = content_type
            .map(|ct| format!("Content-Type: {}\r\n", ct))
            .unwrap_or_default();
        let reason = reqwest::StatusCode::from_u16(status)
            .ok()
            .and_then(|code| code.canonical_reason())
            .unwrap_or("Unknown");
        let head = format!(
            "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nConnection: close\r\n{}\r\n",
            status,
            reason,
            body.len(),
            ct_header
        );
        let handle = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            read_request(&mut stream);
            stream.write_all(head.as_bytes()).unwrap();
            stream.write_all(body).unwrap();
        });
        (addr, handle)
    }

    #[tokio::test]
    async fn http_call_success() {
        let (addr, handle) = start_server(b"ok", 200, None);
        let llm = super::super::PlainLLM::new(&format!("http://{}", addr), "t");
        let empty = json!({});
        let res = llm
            .http_call(Endpoint::ChatCompletion, Method::Post, Some(&empty))
            .await
            .unwrap();
        handle.join().unwrap();
        assert_eq!(res, "ok");
    }

    #[tokio::test]
    async fn http_call_error() {
        let (addr, handle) = start_server(b"fail", 500, None);
        let llm = super::super::PlainLLM::new(&format!("http://{}", addr), "t");
        let empty = json!({});
        let err = llm
            .http_call(Endpoint::ChatCompletion, Method::Post, Some(&empty))
            .await
            .unwrap_err();
        handle.join().unwrap();
        match err {
            crate::Error::HttpStatus(code, body) => {
                assert_eq!(code.as_u16(), 500);
                assert_eq!(body, "fail");
            }
            _ => panic!("unexpected error"),
        }
    }

    #[tokio::test]
    async fn http_call_streamed_returns_chunks() {
        const BODY: &str = "data: {\"a\":1}\n\ndata: [DONE]\n\n";
        let (addr, handle) = start_server(BODY.as_bytes(), 200, Some("text/event-stream"));
        let llm = super::super::PlainLLM::new(&format!("http://{}", addr), "t");
        let empty = json!({});
        let mut stream = llm
            .http_call_streamed(Endpoint::ChatCompletion, Method::Post, Some(&empty))
            .await
            .unwrap();
        let mut collected = Vec::new();
        while let Some(chunk) = stream.next().await {
            collected.push(chunk.unwrap());
        }
        handle.join().unwrap();
        assert_eq!(collected, vec!["{\"a\":1}\n\n[DONE]\n\n"]);
    }

    #[tokio::test]
    async fn stream_llm_parses_sse() {
        const BODY: &str = "data: {\"id\":\"1\",\"object\":\"chunk\",\"created\":0,\"model\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"hi\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n";
        let (addr, handle) = start_server(BODY.as_bytes(), 200, Some("text/event-stream"));
        let llm = super::super::PlainLLM::new(&format!("http://{}", addr), "t");
        let mut req = ChatCompletionRequest::new("m".into(), vec![Message::new("user", "hi")]);
        req.stream = true;
        let (chunks, content) = llm
            .stream_llm(&req, &LLMEventHandlers::default())
            .await
            .unwrap();
        handle.join().unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(content, "hi");
    }

    #[tokio::test]
    async fn stream_llm_concatenates_content_across_events() {
        const BODY: &str = concat!(
            "data: {\"id\":\"1\",\"object\":\"chunk\",\"created\":0,\"model\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"he\"},\"finish_reason\":null}]}\n\n",
            "data: {\"id\":\"1\",\"object\":\"chunk\",\"created\":0,\"model\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"llo\"},\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n"
        );
        let (addr, handle) = start_server(BODY.as_bytes(), 200, Some("text/event-stream"));
        let llm = super::super::PlainLLM::new(&format!("http://{}", addr), "t");
        let mut req = ChatCompletionRequest::new("m".into(), vec![Message::new("user", "hi")]);
        req.stream = true;
        let (chunks, content) = llm
            .stream_llm(&req, &LLMEventHandlers::default())
            .await
            .unwrap();
        handle.join().unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(content, "hello");
    }
}