use std::pin::Pin;
use async_stream::try_stream;
use futures_util::{Stream, StreamExt};
use reqwest::Method;
use serde::Serialize;
mod types;
pub use types::{
ChatChoice, ChatDelta, ChatMessage, ChatResponseMessage, ChatStream, ChatStreamChoice,
ChatStreamEvent, ChatTool, ContentPart, CreateChatCompletionArgs, CreateChatCompletionResponse,
FunctionDefinition, MediaSource, MessageContent, TokenUsage, ToolCall, ToolCallDelta,
ToolCallDeltaFunction, ToolCallFunction,
};
use crate::config::ServiceBase;
use crate::{Client, RekaError, Result};
/// Client handle for the chat completions API.
///
/// Wraps the crate-level [`Client`]; the methods on this type build their
/// requests through it against `ServiceBase::Chat`.
#[derive(Clone)]
pub struct ChatClient {
/// Underlying client used to build every request issued by this type.
client: Client,
}
impl ChatClient {
pub(crate) fn new(client: Client) -> Self {
Self { client }
}
pub async fn create(
&self,
args: &CreateChatCompletionArgs,
) -> Result<CreateChatCompletionResponse> {
self.client
.request(ServiceBase::Chat, Method::POST, "/chat/completions")
.json(&ChatCompletionBody::standard(args))
.send_json()
.await
}
pub async fn stream(&self, args: &CreateChatCompletionArgs) -> Result<ChatStream> {
let events = self
.client
.request(ServiceBase::Chat, Method::POST, "/chat/completions")
.accept("text/event-stream")
.json(&ChatCompletionBody::streaming(args))
.send_sse()
.await?;
let stream = try_stream! {
let mut events = events;
while let Some(event) = events.next().await {
let event = event?;
if event.data == "[DONE]" {
break;
}
let chunk = serde_json::from_str::<ChatStreamEvent>(&event.data)
.map_err(|source| RekaError::decode("/chat/completions", event.data, source))?;
yield chunk;
}
};
Ok(ChatStream {
inner: Box::pin(stream) as Pin<Box<dyn Stream<Item = Result<ChatStreamEvent>> + Send>>,
})
}
}
/// Wire-level request body for the chat completions endpoint.
///
/// Combines the caller-supplied arguments with a `stream` flag so the flag
/// does not have to live on `CreateChatCompletionArgs` itself.
#[derive(Serialize)]
struct ChatCompletionBody<'a> {
/// Caller arguments; `flatten` serializes their fields at the top level of
/// the body instead of nesting them under an `args` key.
#[serde(flatten)]
args: &'a CreateChatCompletionArgs,
/// Whether a streamed response is requested. Skipped during serialization
/// when `false`, so non-streaming bodies carry no `stream` key.
#[serde(skip_serializing_if = "is_false")]
stream: bool,
}
impl<'a> ChatCompletionBody<'a> {
fn standard(args: &'a CreateChatCompletionArgs) -> Self {
Self {
args,
stream: false,
}
}
fn streaming(args: &'a CreateChatCompletionArgs) -> Self {
Self { args, stream: true }
}
}
/// Returns `true` when `value` is `false`.
///
/// Takes a reference because it is used as a serde `skip_serializing_if`
/// predicate, which is called with a reference to the field.
fn is_false(value: &bool) -> bool {
    matches!(*value, false)
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::ChatCompletionBody;
    use crate::{ChatMessage, CreateChatCompletionArgs, ModelId};

    /// The `stream` key must appear only in the streaming body; neither the
    /// standard body nor the bare public arguments may serialize it.
    #[test]
    fn chat_completion_payload_controls_stream_without_polluting_public_request() {
        let model = ModelId::flash();
        let messages = vec![ChatMessage::user("Stream this response.")];
        let args = CreateChatCompletionArgs::new(model, messages).with_max_tokens(32);

        // Serialize the two wire bodies and the raw arguments for comparison.
        let standard_body = ChatCompletionBody::standard(&args);
        let standard_json =
            serde_json::to_value(standard_body).expect("standard payload should serialize");
        let streaming_body = ChatCompletionBody::streaming(&args);
        let streaming_json =
            serde_json::to_value(streaming_body).expect("streaming payload should serialize");
        let args_json = serde_json::to_value(&args).expect("args should serialize");

        // Streaming body: flag present, flattened argument fields at the top level.
        assert_eq!(streaming_json["stream"], true);
        assert_eq!(streaming_json["model"], "reka-flash");

        // Non-streaming body and the public arguments never mention `stream`.
        assert_eq!(standard_json.get("stream"), None);
        assert_eq!(args_json.get("stream"), None);
        assert_eq!(args_json["max_tokens"], json!(32));
    }
}