//! openmodex 0.1.1 — Official Rust SDK for the OpenModex API.
//!
//! Documentation: streaming chat-completion support (`ChatCompletionStream`).
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use futures::Stream;
use pin_project_lite::pin_project;
use reqwest::Response;

use crate::error::Error;
use crate::types::ChatCompletionChunk;

pin_project! {
    /// A stream of [`ChatCompletionChunk`]s from a streaming chat completion request.
    ///
    /// Implements [`futures::Stream`] so it can be used with `StreamExt::next()`.
    pub struct ChatCompletionStream {
        buffer: String,
        done: bool,
        #[pin]
        byte_stream: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
    }
}

impl ChatCompletionStream {
    /// Create a new stream from an HTTP response.
    pub(crate) fn new(response: Response) -> Self {
        let byte_stream = response.bytes_stream();
        Self {
            buffer: String::new(),
            done: false,
            byte_stream: Box::pin(byte_stream),
        }
    }
}

impl Stream for ChatCompletionStream {
    type Item = Result<ChatCompletionChunk, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        if *this.done {
            return Poll::Ready(None);
        }

        loop {
            // Try to parse any complete SSE events from the buffer.
            if let Some(result) = try_parse_sse_event(this.buffer, this.done) {
                return Poll::Ready(Some(result));
            }

            // Read more data from the byte stream.
            match this.byte_stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => {
                    match std::str::from_utf8(&bytes) {
                        Ok(text) => this.buffer.push_str(text),
                        Err(e) => {
                            return Poll::Ready(Some(Err(Error::Stream(format!(
                                "invalid UTF-8 in SSE stream: {e}"
                            )))));
                        }
                    }
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(Error::Http(e))));
                }
                Poll::Ready(None) => {
                    *this.done = true;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

/// Try to extract and parse one SSE data event from the buffer.
///
/// Returns `None` if no complete event is available yet.
fn try_parse_sse_event(
    buffer: &mut String,
    done: &mut bool,
) -> Option<Result<ChatCompletionChunk, Error>> {
    loop {
        // Look for a complete line ending with \n.
        let newline_pos = buffer.find('\n')?;
        let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
        buffer.drain(..=newline_pos);

        // Skip empty lines and comments.
        if line.is_empty() || line.starts_with(':') {
            continue;
        }

        // Parse SSE data lines.
        if let Some(payload) = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:")) {
            let payload = payload.trim();

            // Check for the done sentinel.
            if payload == "[DONE]" {
                *done = true;
                return None;
            }

            match serde_json::from_str::<ChatCompletionChunk>(payload) {
                Ok(chunk) => return Some(Ok(chunk)),
                Err(e) => return Some(Err(Error::Json(e))),
            }
        }

        // Skip other SSE fields (event:, id:, retry:, etc.)
    }
}