reka 0.1.0

Async Rust SDK for the Reka API.
Documentation
use std::pin::Pin;

use async_stream::try_stream;
use futures_util::{Stream, StreamExt};
use reqwest::Method;
use serde::Serialize;

mod types;

pub use types::{
    ChatChoice, ChatDelta, ChatMessage, ChatResponseMessage, ChatStream, ChatStreamChoice,
    ChatStreamEvent, ChatTool, ContentPart, CreateChatCompletionArgs, CreateChatCompletionResponse,
    FunctionDefinition, MediaSource, MessageContent, TokenUsage, ToolCall, ToolCallDelta,
    ToolCallDeltaFunction, ToolCallFunction,
};

use crate::config::ServiceBase;
use crate::{Client, RekaError, Result};

/// Handle for chat completions and streaming chat responses.
///
/// The handle is a thin wrapper around the SDK `Client`: it holds no state of
/// its own, and cloning it clones the wrapped client.
#[derive(Clone)]
pub struct ChatClient {
    // Underlying SDK client from which every chat request is built.
    client: Client,
}

impl ChatClient {
    pub(crate) fn new(client: Client) -> Self {
        Self { client }
    }

    /// Creates a non-streaming chat completion.
    pub async fn create(
        &self,
        args: &CreateChatCompletionArgs,
    ) -> Result<CreateChatCompletionResponse> {
        self.client
            .request(ServiceBase::Chat, Method::POST, "/chat/completions")
            .json(&ChatCompletionBody::standard(args))
            .send_json()
            .await
    }

    /// Creates a streaming chat completion.
    ///
    /// The returned [`ChatStream`](crate::ChatStream) yields typed SSE events
    /// until the API sends `[DONE]`.
    pub async fn stream(&self, args: &CreateChatCompletionArgs) -> Result<ChatStream> {
        let events = self
            .client
            .request(ServiceBase::Chat, Method::POST, "/chat/completions")
            .accept("text/event-stream")
            .json(&ChatCompletionBody::streaming(args))
            .send_sse()
            .await?;

        let stream = try_stream! {
            let mut events = events;

            while let Some(event) = events.next().await {
                let event = event?;
                if event.data == "[DONE]" {
                    break;
                }

                let chunk = serde_json::from_str::<ChatStreamEvent>(&event.data)
                    .map_err(|source| RekaError::decode("/chat/completions", event.data, source))?;

                yield chunk;
            }
        };

        Ok(ChatStream {
            inner: Box::pin(stream) as Pin<Box<dyn Stream<Item = Result<ChatStreamEvent>> + Send>>,
        })
    }
}

/// Request payload sent to the chat completions endpoint.
///
/// Combines the caller's borrowed arguments with a `stream` flag, so the flag
/// is set per request kind rather than carried on the argument type held here.
#[derive(Serialize)]
struct ChatCompletionBody<'a> {
    // Flattened: the argument fields are serialized at the top level of the
    // payload instead of under an `args` key.
    #[serde(flatten)]
    args: &'a CreateChatCompletionArgs,
    // Emitted only when true; a false value is omitted from the payload.
    #[serde(skip_serializing_if = "is_false")]
    stream: bool,
}

impl<'a> ChatCompletionBody<'a> {
    fn standard(args: &'a CreateChatCompletionArgs) -> Self {
        Self {
            args,
            stream: false,
        }
    }

    fn streaming(args: &'a CreateChatCompletionArgs) -> Self {
        Self { args, stream: true }
    }
}

/// Returns `true` when `value` is `false`.
///
/// Used as a serde `skip_serializing_if` predicate, which is why it takes the
/// flag by reference instead of by value.
fn is_false(value: &bool) -> bool {
    let flag = *value;
    !flag
}

#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::ChatCompletionBody;
    use crate::{ChatMessage, CreateChatCompletionArgs, ModelId};

    /// The `stream` flag must appear only in the streaming payload, and the
    /// caller-facing arguments must never serialize a `stream` field.
    #[test]
    fn chat_completion_payload_controls_stream_without_polluting_public_request() {
        let model = ModelId::flash();
        let messages = vec![ChatMessage::user("Stream this response.")];
        let args = CreateChatCompletionArgs::new(model, messages).with_max_tokens(32);

        let standard_body = ChatCompletionBody::standard(&args);
        let standard =
            serde_json::to_value(standard_body).expect("standard payload should serialize");

        let streaming_body = ChatCompletionBody::streaming(&args);
        let streaming =
            serde_json::to_value(streaming_body).expect("streaming payload should serialize");

        let bare = serde_json::to_value(&args).expect("args should serialize");

        // Streaming payload: flag present, flattened arguments still visible.
        assert_eq!(streaming["stream"], true);
        assert_eq!(streaming["model"], "reka-flash");

        // Standard payload: flag omitted entirely.
        assert_eq!(standard.get("stream"), None);

        // Bare arguments: untouched by the wrapper.
        assert_eq!(bare.get("stream"), None);
        assert_eq!(bare["max_tokens"], json!(32));
    }
}