use std::pin::Pin;
use crate::parser::{GenericStreamParser, StreamEventExt};
use crate::types::Thinking;
use crate::Result;
use bytes::Bytes;
use futures::Stream;
use ollama_sdk_macros::FromBytes;
use serde::{Deserialize, Serialize};
use super::{Role, ThinkingLevel};
/// Request body serialized for a chat call.
///
/// Normally produced from [`SimpleChatRequest`] or [`StreamingChatRequest`] through
/// their `From` conversions, which decide the `stream` and `tools` values.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChatRequest {
    /// Name of the model to chat with.
    pub model: String,
    /// Messages making up the conversation sent to the model.
    pub messages: Vec<ChatRequestMessage>,
    /// Streaming flag; left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Tool definitions offered to the model; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolSpec>>,
    /// Thinking configuration; always serialized.
    // `default` only affects deserialization, which this type does not derive.
    #[serde(default)]
    pub think: Thinking,
}
/// A single entry in a chat request's `messages` list.
///
/// The enum is `untagged`: each variant serializes as the bare fields of its inner
/// struct. When deserializing, serde tries the variants in declaration order and keeps
/// the first one that succeeds, so the order below matters:
///
/// * `ToolCallResult` requires `name` and `tool_call_id`, so ordinary messages fail to
///   match it and fall through to `Message`.
/// * `RegularChatRequestMessage` only requires `role` and `content` and does not reject
///   unknown fields. If `Message` were tried first, it would also accept every
///   tool-result payload, and `ToolCallResult` could never be produced.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum ChatRequestMessage {
    /// The result of a tool invocation, sent back to the model.
    // Must stay first: it is the stricter shape (see the enum docs).
    ToolCallResult(ToolCallResultMessage),
    /// Any other message: a role, text content and optional tool calls.
    Message(RegularChatRequestMessage),
}
/// An ordinary chat message: who said it, what was said, and any attached tool calls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegularChatRequestMessage {
    /// Author of the message.
    pub role: Role,
    /// Text of the message.
    pub content: String,
    /// Tool calls attached to this message.
    ///
    /// An absent field deserializes as an empty list. The field is always serialized,
    /// even when empty.
    ///
    /// NOTE(review): responses describe tool calls with `ToolCall`
    /// (`id` + `function { name, arguments }`), whereas this field uses `FunctionalTool`
    /// (`name` / `description` / `parameters`). Confirm that the server accepts this
    /// shape when assistant tool calls are replayed.
    #[serde(default)]
    pub tool_calls: Vec<FunctionalTool>,
}
impl RegularChatRequestMessage {
    /// Creates a message with the given role and text and no tool calls.
    pub fn new(role: Role, content: String) -> Self {
        let tool_calls = Vec::new();
        Self {
            role,
            content,
            tool_calls,
        }
    }

    /// Appends `tool` to this message's tool calls and returns the updated message.
    pub fn add_tool_call(self, tool: FunctionalTool) -> Self {
        let mut message = self;
        message.tool_calls.push(tool);
        message
    }

    /// Wraps this message in [`ChatRequestMessage::Message`].
    pub fn to_chat_request_message(self) -> ChatRequestMessage {
        ChatRequestMessage::Message(self)
    }
}
/// A message carrying the output of a tool call back to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallResultMessage {
    /// Author role; [`ToolCallResultMessage::new`] always sets [`Role::Tool`].
    pub role: Role,
    /// Name of the tool that produced this result.
    pub name: String,
    /// Tool output as text.
    pub content: String,
    /// Identifier of the tool call this result answers.
    pub tool_call_id: String,
}
impl ToolCallResultMessage {
    /// Creates a tool-result message. The role is always [`Role::Tool`].
    pub fn new(name: String, content: String, tool_call_id: String) -> Self {
        let role = Role::Tool;
        Self {
            role,
            name,
            content,
            tool_call_id,
        }
    }

    /// Wraps this message in [`ChatRequestMessage::ToolCallResult`].
    pub fn to_chat_request_message(self) -> ChatRequestMessage {
        ChatRequestMessage::ToolCallResult(self)
    }
}
/// A tool offered to the model in a chat request.
///
/// Serialized as an internally tagged object, for example
/// `{"type": "function", "function": { ... }}`.
/// Variant names are lower-cased so the tag matches the documented `"function"` value.
/// Without the rename, serde would emit the Rust variant name `"Function"`.
#[derive(Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolSpec {
    /// A callable function described by its name, description and parameter schema.
    Function { function: FunctionalTool },
}
/// Definition of a function-style tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionalTool {
    /// Function name the model uses to refer to the tool.
    pub name: String,
    /// Optional human-readable description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter description as arbitrary JSON.
    // NOTE(review): presumably a JSON Schema object; confirm against the API docs.
    pub parameters: serde_json::Value,
}
/// A chat response object as decoded from the server.
///
/// `FromBytes` (a project derive) lets this type be built from raw response bytes.
/// `ChatStream` decodes one of these per streamed item.
#[derive(Debug, Clone, Default, Serialize, Deserialize, FromBytes)]
pub struct ChatResponse {
    /// Model that produced the response.
    pub model: String,
    /// Creation timestamp as sent by the server; empty if the field is absent.
    #[serde(default)]
    pub created_at: String,
    /// The assistant message (or message fragment) carried by this response.
    pub message: ChatResponseMessage,
    /// Completion flag reported by the server.
    // NOTE(review): presumably `true` only on the final chunk of a stream.
    pub done: bool,
}
/// The message part of a [`ChatResponse`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChatResponseMessage {
    /// Author role of the reply.
    pub role: Role,
    /// Reply text (a fragment of it when streaming).
    pub content: String,
    /// Reasoning text; empty when the server does not send it.
    #[serde(default)]
    pub thinking: String,
    /// Tool calls requested by the model; empty when none are sent.
    #[serde(default)]
    pub tool_calls: Vec<ToolCall>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct ToolCall {
pub id: String,
pub function: FunctionInvocation,
}
/// The function part of a [`ToolCall`]: which function to call and with what arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInvocation {
    /// Optional position of this call; `None` when the server does not send one.
    pub index: Option<usize>,
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments as arbitrary JSON.
    pub arguments: serde_json::Value,
}
/// Builder for a non-streaming chat request without tools.
///
/// Convert it into a [`ChatRequest`] with `From`/`Into`, which sets `stream` to
/// `Some(false)` and `tools` to `None`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct SimpleChatRequest {
    /// Name of the model to chat with.
    pub model: String,
    /// Messages accumulated so far.
    pub messages: Vec<ChatRequestMessage>,
    /// Thinking configuration.
    pub think: Thinking,
}
impl SimpleChatRequest {
    /// Starts an empty request for `model`, with the default thinking setting.
    pub fn new(model: String) -> Self {
        Self {
            think: Thinking::default(),
            messages: Vec::new(),
            model,
        }
    }

    /// Appends an ordinary message.
    pub fn add_message(mut self, message: RegularChatRequestMessage) -> Self {
        let entry = ChatRequestMessage::Message(message);
        self.messages.push(entry);
        self
    }

    /// Appends the result of a tool call.
    pub fn add_tool_call_result(mut self, message: ToolCallResultMessage) -> Self {
        let entry = ChatRequestMessage::ToolCallResult(message);
        self.messages.push(entry);
        self
    }

    /// Sets `think` to `Thinking::Boolean(true)`.
    pub fn enable_thinking(self) -> Self {
        Self {
            think: Thinking::Boolean(true),
            ..self
        }
    }

    /// Sets `think` to `Thinking::Boolean(false)`.
    pub fn disable_thinking(self) -> Self {
        Self {
            think: Thinking::Boolean(false),
            ..self
        }
    }

    /// Sets `think` to an explicit thinking level.
    pub fn set_thinking_level(self, level: ThinkingLevel) -> Self {
        Self {
            think: Thinking::Level(level),
            ..self
        }
    }
}
/// Builder for a streaming chat request, optionally offering tools.
///
/// Convert it into a [`ChatRequest`] with `From`/`Into`, which sets `stream` to
/// `Some(true)` and passes `tools` through unchanged.
#[derive(Debug, Clone, Default, Serialize)]
pub struct StreamingChatRequest {
    /// Name of the model to chat with.
    pub model: String,
    /// Messages accumulated so far.
    pub messages: Vec<ChatRequestMessage>,
    /// Tools offered to the model; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolSpec>>,
    /// Thinking configuration.
    pub think: Thinking,
}
impl StreamingChatRequest {
    /// Starts an empty request for `model`: no messages, no tools, default thinking.
    pub fn new(model: String) -> Self {
        Self {
            think: Thinking::default(),
            tools: None,
            messages: Vec::new(),
            model,
        }
    }

    /// Appends an already-wrapped message of either kind.
    pub fn add_message(mut self, message: ChatRequestMessage) -> Self {
        self.messages.push(message);
        self
    }

    /// Appends an ordinary message.
    pub fn add_regular_message(mut self, message: RegularChatRequestMessage) -> Self {
        let entry = ChatRequestMessage::Message(message);
        self.messages.push(entry);
        self
    }

    /// Appends the result of a tool call.
    pub fn add_tool_call_result(mut self, message: ToolCallResultMessage) -> Self {
        let entry = ChatRequestMessage::ToolCallResult(message);
        self.messages.push(entry);
        self
    }

    /// Sets `think` to `Thinking::Boolean(true)`.
    pub fn enable_thinking(self) -> Self {
        Self {
            think: Thinking::Boolean(true),
            ..self
        }
    }

    /// Sets `think` to `Thinking::Boolean(false)`.
    pub fn disable_thinking(self) -> Self {
        Self {
            think: Thinking::Boolean(false),
            ..self
        }
    }

    /// Sets `think` to an explicit thinking level.
    pub fn set_thinking_level(self, level: ThinkingLevel) -> Self {
        Self {
            think: Thinking::Level(level),
            ..self
        }
    }

    /// Replaces the tool list offered to the model.
    pub fn tools(self, tools: Vec<ToolSpec>) -> Self {
        Self {
            tools: Some(tools),
            ..self
        }
    }
}
impl From<SimpleChatRequest> for ChatRequest {
fn from(value: SimpleChatRequest) -> Self {
ChatRequest {
model: value.model,
messages: value.messages,
stream: Some(false),
think: value.think,
tools: None,
}
}
}
impl From<StreamingChatRequest> for ChatRequest {
fn from(value: StreamingChatRequest) -> Self {
ChatRequest {
model: value.model,
messages: value.messages,
stream: Some(true),
think: value.think,
tools: value.tools,
}
}
}
/// One item yielded by [`ChatStream`].
#[derive(Debug, Serialize, Deserialize)]
pub enum ChatStreamEvent {
    /// A decoded chat response object.
    Message(ChatResponse),
    /// An error described as text.
    Error(String),
    /// Text that did not form a complete message, with an optional error description.
    // NOTE(review): presumably `partial` is undecoded leftover input and `error` the
    // decode failure, if any; confirm against `GenericStreamParser`.
    Partial {
        partial: String,
        error: Option<String>,
    },
}
/// A boxed, pinned, `Send` stream of [`ChatStreamEvent`] results.
pub struct ChatStream {
    /// The underlying event stream that polling is delegated to.
    pub inner: Pin<Box<dyn Stream<Item = Result<ChatStreamEvent>> + Send>>,
}
impl Stream for ChatStream {
type Item = Result<ChatStreamEvent>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.inner.as_mut().poll_next(cx)
}
}
impl ChatStream {
    /// Wraps a stream of raw response byte chunks in a chat event stream.
    ///
    /// Parsing is done by `GenericStreamParser`, parameterized with [`ChatResponse`] as
    /// the decoded message type and [`ChatStreamEvent`] as the emitted event type.
    pub fn from_bytes_stream<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<Bytes>> + Send + Unpin + 'static,
    {
        Self {
            inner: Box::pin(GenericStreamParser::<S, ChatResponse, ChatStreamEvent>::new(stream)),
        }
    }
}
impl StreamEventExt<ChatResponse> for ChatStreamEvent {
    /// Wraps a decoded response as [`ChatStreamEvent::Message`].
    fn from_message(msg: ChatResponse) -> Self {
        Self::Message(msg)
    }

    /// Wraps an error description as [`ChatStreamEvent::Error`].
    fn from_error(err: String) -> Self {
        Self::Error(err)
    }

    /// Builds a [`ChatStreamEvent::Partial`] from leftover text and an optional error.
    fn partial(partial: String, error: Option<String>) -> Self {
        Self::Partial { partial, error }
    }
}