use std::collections::BTreeMap;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::{ModelId, Result};
/// Request body for a chat completion call.
///
/// Build one with [`CreateChatCompletionArgs::new`] and the `with_*` setters.
/// Optional parameters left as `None` are omitted from the serialized JSON, and
/// the entries of [`extra`](Self::extra) are flattened into the top-level object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CreateChatCompletionArgs {
/// Model to run the completion with.
pub model: ModelId,
/// Conversation messages, sent in the order given.
pub messages: Vec<ChatMessage>,
/// `temperature` request parameter; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// `max_tokens` request parameter; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// `stop` sequences; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
/// Tools the model may call; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatTool>>,
/// `tool_choice` request parameter (e.g. `"auto"`); omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
/// Additional top-level parameters that have no typed field. Flattened into the
/// request object on serialization; on deserialization it collects every key
/// not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
impl CreateChatCompletionArgs {
    /// Creates a request for `model` over `messages` with no optional parameters set.
    #[must_use]
    pub fn new(model: ModelId, messages: Vec<ChatMessage>) -> Self {
        Self {
            model,
            messages,
            temperature: None,
            max_tokens: None,
            stop: None,
            tools: None,
            tool_choice: None,
            extra: BTreeMap::new(),
        }
    }

    /// Sets the `temperature` request parameter.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the `max_tokens` request parameter.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the `stop` sequences, replacing any previously set list.
    ///
    /// Accepts any iterable of string-like values, e.g. `["\n\n", "END"]`.
    #[must_use]
    pub fn with_stop<I, S>(mut self, stop: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.stop = Some(stop.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the tools the model may call, replacing any previously set list.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<ChatTool>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets the `tool_choice` request parameter (e.g. `"auto"`).
    #[must_use]
    pub fn with_tool_choice(mut self, tool_choice: impl Into<String>) -> Self {
        self.tool_choice = Some(tool_choice.into());
        self
    }

    /// Adds a top-level request parameter that has no typed field.
    ///
    /// The entry is flattened into the serialized request object. Inserting a key
    /// that is already present in `extra` replaces its previous value.
    ///
    /// Keys that match a typed field (e.g. `"temperature"`) are *not* merged with
    /// it: both are emitted on serialization, which can produce a duplicate JSON
    /// key. Use the dedicated `with_*` setter for those parameters instead.
    #[must_use]
    pub fn insert_extra(mut self, key: impl Into<String>, value: Value) -> Self {
        self.extra.insert(key.into(), value);
        self
    }
}
/// A single message in a chat completion request.
///
/// Use the constructors (`user`, `assistant`, `system`, `tool`,
/// `assistant_with_tool_calls`) for the common cases. Fields that are `None` are
/// omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatMessage {
/// Speaker role, e.g. `"user"`, `"assistant"`, `"system"` or `"tool"`.
pub role: String,
/// Message body; `None` for assistant messages that only carry tool calls.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<MessageContent>,
/// Tool calls previously returned by the model, echoed back in follow-up turns.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// For `"tool"` messages: the id of the tool call this message answers.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Extra message keys with no typed field; flattened into the message object.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
impl ChatMessage {
    /// Builds a message with an arbitrary `role` and the given content.
    ///
    /// `tool_calls` and `tool_call_id` are left unset and `extra` starts empty.
    pub fn new(role: impl Into<String>, content: impl Into<MessageContent>) -> Self {
        let role: String = role.into();
        let body: MessageContent = content.into();
        Self {
            role,
            content: Some(body),
            tool_calls: None,
            tool_call_id: None,
            extra: BTreeMap::default(),
        }
    }

    /// Shorthand for a `"user"` message.
    pub fn user(content: impl Into<MessageContent>) -> Self {
        Self::new("user", content)
    }

    /// Shorthand for an `"assistant"` message with content.
    pub fn assistant(content: impl Into<MessageContent>) -> Self {
        Self::new("assistant", content)
    }

    /// Shorthand for a `"system"` message.
    pub fn system(content: impl Into<MessageContent>) -> Self {
        Self::new("system", content)
    }

    /// Builds a `"tool"` message that answers the tool call identified by
    /// `tool_call_id` with `content`.
    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<MessageContent>) -> Self {
        let mut reply = Self::new("tool", content);
        reply.tool_call_id = Some(tool_call_id.into());
        reply
    }

    /// Builds an `"assistant"` message that carries only `tool_calls`.
    ///
    /// `content` is `None`, so the key is omitted when serialized.
    pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
        Self {
            content: None,
            tool_calls: Some(tool_calls),
            ..Self::assistant(String::new())
        }
    }
}
/// Message content: plain text or a list of typed parts (e.g. text plus images).
///
/// Serialized untagged: `Text` becomes a bare JSON string and `Parts` a JSON
/// array. On deserialization serde tries the variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
/// Plain text body.
Text(String),
/// Ordered list of content parts.
Parts(Vec<ContentPart>),
}
impl From<String> for MessageContent {
fn from(value: String) -> Self {
Self::Text(value)
}
}
impl From<&str> for MessageContent {
fn from(value: &str) -> Self {
Self::Text(value.to_string())
}
}
impl From<Vec<ContentPart>> for MessageContent {
fn from(value: Vec<ContentPart>) -> Self {
Self::Parts(value)
}
}
/// One element of a multi-part message.
///
/// Serialized as an object whose `"type"` key is the snake_case variant name
/// (`"text"`, `"image_url"`, `"video_url"`, `"audio_url"`, `"pdf_url"`), with the
/// variant's field stored alongside it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// Text part: `{"type": "text", "text": ...}`.
Text { text: String },
/// Image part: `{"type": "image_url", "image_url": ...}`.
ImageUrl { image_url: MediaSource },
/// Video part: `{"type": "video_url", "video_url": ...}`.
VideoUrl { video_url: MediaSource },
/// Audio part: `{"type": "audio_url", "audio_url": ...}`.
AudioUrl { audio_url: MediaSource },
/// PDF part: `{"type": "pdf_url", "pdf_url": ...}`.
PdfUrl { pdf_url: MediaSource },
}
impl ContentPart {
    /// A plain-text part.
    pub fn text(text: impl Into<String>) -> Self {
        let text: String = text.into();
        Self::Text { text }
    }

    /// An image part whose URL serializes as a bare string.
    pub fn image_url(url: impl Into<String>) -> Self {
        let image_url = MediaSource::Url(url.into());
        Self::ImageUrl { image_url }
    }

    /// An image part whose URL serializes as an object: `{"url": ...}`.
    pub fn image_url_object(url: impl Into<String>) -> Self {
        let image_url = MediaSource::UrlObject { url: url.into() };
        Self::ImageUrl { image_url }
    }

    /// A video part whose URL serializes as a bare string.
    pub fn video_url(url: impl Into<String>) -> Self {
        let video_url = MediaSource::Url(url.into());
        Self::VideoUrl { video_url }
    }

    /// An audio part whose URL serializes as a bare string.
    pub fn audio_url(url: impl Into<String>) -> Self {
        let audio_url = MediaSource::Url(url.into());
        Self::AudioUrl { audio_url }
    }

    /// A PDF part whose URL serializes as a bare string.
    pub fn pdf_url(url: impl Into<String>) -> Self {
        let pdf_url = MediaSource::Url(url.into());
        Self::PdfUrl { pdf_url }
    }
}
/// Location of a media item referenced by a [`ContentPart`].
///
/// Serialized untagged: `Url` is a bare JSON string and `UrlObject` is
/// `{"url": ...}`. On deserialization a string matches `Url` and an object with a
/// `url` key matches `UrlObject`. Which shape a given model expects is not
/// decided here (NOTE(review): confirm against the API docs).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MediaSource {
/// Bare URL string.
Url(String),
/// URL wrapped in an object.
UrlObject { url: String },
}
impl MediaSource {
pub fn url(url: impl Into<String>) -> Self {
Self::Url(url.into())
}
pub fn url_object(url: impl Into<String>) -> Self {
Self::UrlObject { url: url.into() }
}
}
/// A tool the model may call, as sent in a request's `tools` list.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatTool {
/// Tool kind, serialized as `"type"`; [`ChatTool::function`] sets it to `"function"`.
#[serde(rename = "type")]
pub tool_type: String,
/// The callable function's definition.
pub function: FunctionDefinition,
/// Extra tool keys with no typed field; flattened into the tool object.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
impl ChatTool {
    /// Wraps `function` as a tool whose `type` is `"function"`, with no extras.
    pub fn function(function: FunctionDefinition) -> Self {
        let kind = String::from("function");
        Self {
            tool_type: kind,
            function,
            extra: BTreeMap::default(),
        }
    }
}
/// Definition of a function exposed to the model as a tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionDefinition {
/// Function name the model uses when calling it.
pub name: String,
/// Optional description; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Parameter specification, passed through as raw JSON. Presumably a JSON Schema
/// object (the tests pass one); NOTE(review): confirm the shape the API requires.
pub parameters: Value,
/// Extra function keys with no typed field; flattened into the function object.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
impl FunctionDefinition {
    /// Creates a definition named `name` with the given `parameters` and no description.
    ///
    /// `parameters` is sent as-is, without validation.
    #[must_use]
    pub fn new(name: impl Into<String>, parameters: Value) -> Self {
        Self {
            name: name.into(),
            description: None,
            parameters,
            extra: BTreeMap::new(),
        }
    }

    /// Sets the description, replacing any previous one.
    #[must_use]
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }
}
/// Response body of a non-streaming chat completion.
///
/// Only `id` and `choices` are required when deserializing. The other typed
/// fields fall back to `None`, and any remaining top-level keys are kept in `extra`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CreateChatCompletionResponse {
/// Completion identifier returned by the API.
pub id: String,
/// Generated choices.
pub choices: Vec<ChatChoice>,
/// Creation timestamp as reported by the API, if present. Presumably Unix seconds;
/// NOTE(review): confirm.
#[serde(default)]
pub created: Option<u64>,
/// Model that produced the completion, if reported.
#[serde(default)]
pub model: Option<ModelId>,
/// Token accounting, if reported.
#[serde(default)]
pub usage: Option<TokenUsage>,
/// Top-level keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// A `Stream` of chat-completion chunks; each item is a decoded
/// [`ChatStreamEvent`] or an error.
///
/// The underlying stream is type-erased behind `Pin<Box<dyn Stream + Send>>`.
/// Because `inner` is `pub(crate)`, only code inside this crate can build one.
pub struct ChatStream {
pub(crate) inner: Pin<Box<dyn Stream<Item = Result<ChatStreamEvent>> + Send>>,
}
impl Stream for ChatStream {
    type Item = Result<ChatStreamEvent>;

    /// Polls the wrapped stream for the next chunk.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ChatStream` is `Unpin` (its only field is a `Pin<Box<_>>`), so reaching
        // `inner` through `Pin<&mut Self>` needs no pin projection.
        self.inner.as_mut().poll_next(cx)
    }

    /// Forwards the wrapped stream's bounds.
    ///
    /// Without this override the trait default `(0, None)` would hide them from
    /// callers that consult `size_hint`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// One chunk of a streaming chat completion.
///
/// Every field is optional on the wire. In particular, a chunk that omits
/// `choices` (e.g. a metadata- or usage-only chunk, depending on the provider)
/// deserializes with an empty `choices` list instead of failing the stream.
/// Unrecognized keys are kept in [`extra`](Self::extra).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatStreamEvent {
    /// Completion id, when the chunk includes one.
    #[serde(default)]
    pub id: Option<String>,
    /// Incremental per-choice updates; empty when the chunk omits `choices`.
    #[serde(default)]
    pub choices: Vec<ChatStreamChoice>,
    /// Creation timestamp, when the chunk includes one.
    #[serde(default)]
    pub created: Option<u64>,
    /// Model that produced the chunk, when reported.
    #[serde(default)]
    pub model: Option<ModelId>,
    /// Token accounting, when reported in this chunk.
    #[serde(default)]
    pub usage: Option<TokenUsage>,
    /// Top-level keys not matched by the fields above.
    #[serde(default, flatten)]
    pub extra: BTreeMap<String, Value>,
}
/// One choice inside a streaming chunk, carrying an incremental [`ChatDelta`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatStreamChoice {
/// Position of this choice, when provided.
#[serde(default)]
pub index: Option<u32>,
/// Finish reason, when the chunk reports one.
#[serde(default)]
pub finish_reason: Option<String>,
/// Incremental update for this choice; this key is required on the wire.
pub delta: ChatDelta,
/// Choice keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Incremental message fragment inside a streaming chunk.
///
/// Every field is optional, so a chunk may carry any subset of them.
/// NOTE(review): presumably text fragments are meant to be concatenated across
/// chunks by the caller; nothing here does that.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatDelta {
/// Speaker role, when included.
#[serde(default)]
pub role: Option<String>,
/// Text fragment, when included.
#[serde(default)]
pub content: Option<String>,
/// Tool-call fragments, when included.
#[serde(default)]
pub tool_calls: Option<Vec<ToolCallDelta>>,
/// Reasoning text fragment, when the provider sends one.
#[serde(default)]
pub reasoning_content: Option<String>,
/// Reasoning steps kept as raw JSON; their schema is not modelled here.
#[serde(default)]
pub reasoning_steps: Option<Vec<Value>>,
/// Delta keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Fragment of a tool call inside a streaming [`ChatDelta`].
///
/// All fields are optional. NOTE(review): presumably fragments sharing the same
/// `index` belong to one call and their `arguments` must be concatenated;
/// confirm against the API's streaming docs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallDelta {
/// Position of the tool call this fragment belongs to, when provided.
#[serde(default)]
pub index: Option<u32>,
/// Tool call id, when included in this fragment.
#[serde(default)]
pub id: Option<String>,
/// Tool kind (wire key `"type"`), when included.
#[serde(rename = "type", default)]
pub tool_type: Option<String>,
/// Function name/arguments fragment, when included.
#[serde(default)]
pub function: Option<ToolCallDeltaFunction>,
/// Keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Function portion of a streamed tool-call fragment.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallDeltaFunction {
/// Function name, when included in this fragment.
#[serde(default)]
pub name: Option<String>,
/// Raw arguments text fragment, when included; it is not parsed here.
#[serde(default)]
pub arguments: Option<String>,
/// Keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// One choice in a non-streaming [`CreateChatCompletionResponse`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatChoice {
/// Position of this choice, when provided.
#[serde(default)]
pub index: Option<u32>,
/// Finish reason, when provided (e.g. `"tool_calls"`).
#[serde(default)]
pub finish_reason: Option<String>,
/// The generated message; this key is required on the wire.
pub message: ChatResponseMessage,
/// Choice keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Message generated by the model in a non-streaming response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatResponseMessage {
/// Speaker role; this key is required on the wire.
pub role: String,
/// Text content; `None` when absent or `null`, e.g. when the model only returns tool calls.
#[serde(default)]
pub content: Option<String>,
/// Tool calls requested by the model, if any.
#[serde(default)]
pub tool_calls: Option<Vec<ToolCall>>,
/// Annotations kept as raw JSON; their schema is not modelled here.
#[serde(default)]
pub annotations: Option<Vec<Value>>,
/// Reasoning text, when the provider returns it.
#[serde(default)]
pub reasoning_content: Option<String>,
/// Reasoning steps kept as raw JSON; their schema is not modelled here.
#[serde(default)]
pub reasoning_steps: Option<Vec<Value>>,
/// Message keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// A complete tool call.
///
/// It appears in responses, and is echoed back in requests via
/// [`ChatMessage::assistant_with_tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Tool call id, referenced by the answering `"tool"` message's `tool_call_id`.
pub id: String,
/// Tool kind, serialized as `"type"` (e.g. `"function"`).
#[serde(rename = "type")]
pub tool_type: String,
/// Function name and arguments.
pub function: ToolCallFunction,
/// Keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Function portion of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallFunction {
/// Name of the function to invoke.
pub name: String,
/// Arguments as a string, e.g. `{"product_id": "a-12345"}` holding JSON text.
/// It is kept verbatim and not parsed here.
pub arguments: String,
/// Keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Token accounting reported by the API; each count is optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
/// `prompt_tokens` count, when reported.
#[serde(default)]
pub prompt_tokens: Option<u32>,
/// `completion_tokens` count, when reported.
#[serde(default)]
pub completion_tokens: Option<u32>,
/// `total_tokens` count, when reported.
#[serde(default)]
pub total_tokens: Option<u32>,
/// Usage keys not matched by the fields above.
#[serde(default, flatten)]
pub extra: BTreeMap<String, Value>,
}
#[cfg(test)]
mod tests {
    use serde_json::{Value, json};

    use super::{
        ChatMessage, ChatStreamEvent, ChatTool, ContentPart, CreateChatCompletionArgs,
        FunctionDefinition, MessageContent, ToolCall, ToolCallFunction,
    };
    use crate::ModelId;

    #[test]
    fn serializes_multimodal_messages() {
        let request = CreateChatCompletionArgs::new(
            ModelId::flash(),
            vec![ChatMessage::user(vec![
                ContentPart::image_url_object("https://example.com/cat.jpg"),
                ContentPart::text("Describe this image."),
            ])],
        )
        .with_max_tokens(256);
        let json = serde_json::to_value(&request).expect("request should serialize");
        assert_eq!(json["model"], "reka-flash");
        assert_eq!(json["messages"][0]["role"], "user");
        assert_eq!(json["messages"][0]["content"][0]["type"], "image_url");
        assert_eq!(
            json["messages"][0]["content"][0]["image_url"]["url"],
            "https://example.com/cat.jpg"
        );
        assert_eq!(json["messages"][0]["content"][1]["type"], "text");
        assert_eq!(
            json["messages"][0]["content"][1]["text"],
            "Describe this image."
        );
    }

    #[test]
    fn serializes_plain_url_media_sources_as_strings() {
        let part = ContentPart::image_url("https://example.com/cat.jpg");
        let json = serde_json::to_value(&part).expect("content part should serialize");
        assert_eq!(
            json,
            json!({
                "type": "image_url",
                "image_url": "https://example.com/cat.jpg"
            })
        );
    }

    #[test]
    fn deserializes_object_url_media_sources() {
        let part: ContentPart = serde_json::from_value(json!({
            "type": "image_url",
            "image_url": { "url": "https://example.com/cat.jpg" }
        }))
        .expect("content part should deserialize");
        assert_eq!(
            part,
            ContentPart::image_url_object("https://example.com/cat.jpg")
        );
    }

    #[test]
    fn serializes_tool_calling_requests() {
        let request = CreateChatCompletionArgs::new(
            ModelId::flash(),
            vec![ChatMessage::user("Is product a-12345 in stock right now?")],
        )
        .with_tools(vec![ChatTool::function(
            FunctionDefinition::new(
                "get_product_availability",
                json!({
                    "type": "object",
                    "properties": {
                        "product_id": {
                            "type": "string"
                        }
                    },
                    "required": ["product_id"]
                }),
            )
            .with_description("Determine whether or not a product is currently in stock."),
        )])
        .with_tool_choice("auto");
        let json = serde_json::to_value(&request).expect("request should serialize");
        assert_eq!(json["tool_choice"], "auto");
        assert_eq!(json["tools"][0]["type"], "function");
        assert_eq!(
            json["tools"][0]["function"]["name"],
            "get_product_availability"
        );
    }

    #[test]
    fn flattens_extra_request_parameters_and_omits_unset_ones() {
        let request =
            CreateChatCompletionArgs::new(ModelId::flash(), vec![ChatMessage::user("Hi")])
                .insert_extra("stream", json!(true));
        let json = serde_json::to_value(&request).expect("request should serialize");
        assert_eq!(json["stream"], true);
        // `extra` is flattened, never nested, and unset options are skipped.
        assert!(json.get("extra").is_none());
        assert!(json.get("temperature").is_none());
    }

    #[test]
    fn collects_unknown_message_fields_into_extra() {
        let message: ChatMessage = serde_json::from_value(json!({
            "role": "user",
            "content": "Hi",
            "name": "alice"
        }))
        .expect("message should deserialize");
        assert_eq!(
            message.content,
            Some(MessageContent::Text("Hi".to_string()))
        );
        assert_eq!(message.extra.get("name"), Some(&json!("alice")));
    }

    #[test]
    fn deserializes_tool_call_responses_with_null_content() {
        let response: Value = json!({
            "id": "chatcmpl_123",
            "choices": [{
                "index": 0,
                "finish_reason": "tool_calls",
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_product_availability",
                            "arguments": "{\"product_id\": \"a-12345\"}"
                        }
                    }]
                }
            }]
        });
        let parsed = serde_json::from_value::<super::CreateChatCompletionResponse>(response)
            .expect("response should deserialize");
        let tool_call = parsed.choices[0].message.tool_calls.as_ref().unwrap();
        assert!(parsed.choices[0].message.content.is_none());
        assert_eq!(tool_call[0].function.name, "get_product_availability");
    }

    #[test]
    fn builds_assistant_and_tool_messages_for_follow_up_turns() {
        let assistant = ChatMessage::assistant_with_tool_calls(vec![ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: ToolCallFunction {
                name: "get_product_availability".to_string(),
                arguments: "{\"product_id\": \"a-12345\"}".to_string(),
                extra: Default::default(),
            },
            extra: Default::default(),
        }]);
        let tool = ChatMessage::tool("call_123", "{\"status\": \"AVAILABLE\"}");
        let assistant_json = serde_json::to_value(&assistant).expect("assistant should serialize");
        let tool_json = serde_json::to_value(&tool).expect("tool should serialize");
        assert_eq!(assistant_json["role"], "assistant");
        // `content` must be omitted entirely, not serialized as `null`.
        assert!(assistant_json.get("content").is_none());
        assert_eq!(assistant_json["tool_calls"][0]["id"], "call_123");
        assert_eq!(tool_json["role"], "tool");
        assert_eq!(tool_json["tool_call_id"], "call_123");
        assert_eq!(tool_json["content"], "{\"status\": \"AVAILABLE\"}");
    }

    #[test]
    fn deserializes_stream_chunks() {
        let chunk = json!({
            "id": "chatcmpl_123",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "Hello"
                }
            }]
        });
        let parsed = serde_json::from_value::<ChatStreamEvent>(chunk)
            .expect("stream chunk should deserialize");
        assert_eq!(parsed.choices[0].delta.role.as_deref(), Some("assistant"));
        assert_eq!(parsed.choices[0].delta.content.as_deref(), Some("Hello"));
    }

    #[test]
    fn deserializes_tool_call_stream_deltas() {
        let chunk = json!({
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_product_availability",
                            "arguments": "{\"product_id\""
                        }
                    }]
                }
            }]
        });
        let parsed = serde_json::from_value::<ChatStreamEvent>(chunk)
            .expect("stream chunk should deserialize");
        assert!(parsed.id.is_none());
        let delta = &parsed.choices[0].delta;
        assert!(delta.content.is_none());
        let tool_call = &delta.tool_calls.as_ref().expect("tool call deltas")[0];
        assert_eq!(tool_call.index, Some(0));
        assert_eq!(tool_call.id.as_deref(), Some("call_123"));
        assert_eq!(tool_call.tool_type.as_deref(), Some("function"));
        let function = tool_call.function.as_ref().expect("function delta");
        assert_eq!(function.name.as_deref(), Some("get_product_availability"));
        assert_eq!(function.arguments.as_deref(), Some("{\"product_id\""));
    }
}