use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::operation::common::{Parameters, StreamOptions};
use crate::operation::request::RequestTrait;
/// Request payload for a text-generation call.
///
/// Construct it with the derived `GenerationParamBuilder`. Optional fields
/// that are `None` are left out of the serialized output instead of being
/// written as `null`.
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct GenerationParam {
    /// Model identifier, sent as-is.
    // `strip_option` was dropped here: the field is not an `Option`, so it had no effect.
    #[builder(setter(into))]
    pub model: String,
    /// Conversation input (the message list).
    pub input: Input,
    /// Optional generation parameters; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(into, strip_option))]
    #[builder(default=None)]
    pub parameters: Option<Parameters>,
    /// Streaming flag; the builder defaults it to `Some(false)`. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(into, strip_option))]
    #[builder(default=Some(false))]
    pub stream: Option<bool>,
    /// Optional streaming options; omitted when `None` (previously sent as `null`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(into, strip_option))]
    #[builder(default=None)]
    pub stream_options: Option<StreamOptions>,
}
/// The `input` section of a generation request.
///
/// The `messages` setter on `InputBuilder` is written by hand (see
/// `setter(custom)`) rather than generated.
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct Input {
    /// Ordered conversation messages.
    #[builder(setter(custom))]
    pub messages: Vec<Message>,
}
impl InputBuilder {
    /// Stores `value` as the message list, replacing any list set earlier.
    pub fn messages(&mut self, value: Vec<Message>) -> &mut Self {
        let slot = &mut self.messages;
        *slot = Some(value);
        self
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
#[derive(Default)]
pub enum Message {
#[default]
None,
System(SystemMessage),
User(UserMessage),
Assistant(AssistantMessage),
Tool(ToolMessage),
}
/// Role-agnostic builder that produces a concrete [`Message`] variant.
///
/// `Default` leaves `role` empty, which `build` rejects; set a role first.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct MessageBuilder {
    /// Role selecting the variant: "system", "user", "assistant" or "tool".
    pub role: String,
    /// Text content of the message.
    pub content: String,
    /// Partial flag; only carried into assistant messages.
    pub partial: Option<bool>,
    /// Tool calls; only carried into assistant messages.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Tool-call id; only carried into tool messages.
    pub tool_call_id: Option<String>,
}
/// Failure modes of `MessageBuilder::build`.
#[derive(Debug, Error)]
pub enum MessageBuilderError {
    /// The builder's role was not one of the supported role strings.
    #[error("Invalid role")]
    InvalidRole,
}
impl MessageBuilder {
    /// Creates a builder with the given role and content.
    ///
    /// `partial` starts as `Some(false)`; `tool_calls` and `tool_call_id` start as `None`.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
            partial: Some(false),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// Sets the role string (any value; it is validated only in [`Self::build`]).
    pub fn role(&mut self, value: impl Into<String>) -> &mut Self {
        self.role = value.into();
        self
    }

    /// Shorthand for `role("system")`.
    pub fn system(&mut self) -> &mut Self {
        self.role("system")
    }

    /// Shorthand for `role("user")`.
    pub fn user(&mut self) -> &mut Self {
        self.role("user")
    }

    /// Shorthand for `role("assistant")`.
    pub fn assistant(&mut self) -> &mut Self {
        self.role("assistant")
    }

    /// Shorthand for `role("tool")`.
    pub fn tool(&mut self) -> &mut Self {
        self.role("tool")
    }

    /// Sets the message text.
    pub fn content(&mut self, value: impl Into<String>) -> &mut Self {
        self.content = value.into();
        self
    }

    /// Sets the partial flag (used only for assistant messages).
    pub fn partial(&mut self, value: bool) -> &mut Self {
        self.partial = Some(value);
        self
    }

    /// Sets the tool-call id (used only for tool messages).
    pub fn tool_call_id(&mut self, value: impl Into<String>) -> &mut Self {
        self.tool_call_id = Some(value.into());
        self
    }

    /// Sets the tool calls (used only for assistant messages).
    pub fn tool_calls(&mut self, value: Vec<ToolCall>) -> &mut Self {
        self.tool_calls = Some(value);
        self
    }

    /// Builds the [`Message`] variant selected by `role`.
    ///
    /// Only the fields relevant to that variant are copied: `partial` and
    /// `tool_calls` for `"assistant"`, `tool_call_id` for `"tool"`.
    ///
    /// # Errors
    ///
    /// Returns [`MessageBuilderError::InvalidRole`] unless `role` is exactly
    /// `"system"`, `"user"`, `"assistant"` or `"tool"` (case-sensitive).
    pub fn build(&self) -> Result<Message, MessageBuilderError> {
        // Clone individual fields only; cloning the whole builder just to read
        // `content` would also deep-copy `tool_calls` for nothing.
        match self.role.as_str() {
            "system" => Ok(Message::System(SystemMessage {
                role: self.role.clone(),
                content: self.content.clone(),
            })),
            "user" => Ok(Message::User(UserMessage {
                role: self.role.clone(),
                content: self.content.clone(),
            })),
            "assistant" => Ok(Message::Assistant(AssistantMessage {
                role: self.role.clone(),
                content: self.content.clone(),
                partial: self.partial,
                tool_calls: self.tool_calls.clone(),
            })),
            "tool" => Ok(Message::Tool(ToolMessage {
                role: self.role.clone(),
                content: self.content.clone(),
                tool_call_id: self.tool_call_id.clone(),
            })),
            _ => Err(MessageBuilderError::InvalidRole),
        }
    }
}
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct SystemMessage {
#[builder(setter(into), default = "\"system\".to_string()")]
pub role: String,
#[builder(setter(into))]
pub content: String,
}
impl From<SystemMessage> for Message {
fn from(value: SystemMessage) -> Self {
Self::System(value)
}
}
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct UserMessage {
#[builder(setter(into), default = "\"user\".to_string()")]
pub role: String,
#[builder(setter(into))]
pub content: String,
}
impl From<UserMessage> for Message {
fn from(value: UserMessage) -> Self {
Self::User(value)
}
}
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessage {
#[builder(setter(into), default = "\"assistant\".to_string()")]
pub role: String,
#[builder(setter(into))]
pub content: String,
#[builder(setter(into, strip_option))]
#[builder(default=Some(false))]
pub partial: Option<bool>,
#[builder(setter(into, strip_option))]
pub tool_calls: Option<Vec<ToolCall>>,
}
impl From<AssistantMessage> for Message {
fn from(value: AssistantMessage) -> Self {
Self::Assistant(value)
}
}
/// A single tool invocation attached to an assistant message.
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    #[builder(setter(into))]
    pub id: String,
    /// Call kind; (de)serialized under the key `"type"` (a Rust keyword, hence `type_`).
    #[serde(rename = "type")]
    #[builder(setter(into))]
    pub type_: String,
    /// Function name and arguments to invoke.
    pub function: Function,
    /// Position of the call — presumably within the list of tool calls; confirm against the API.
    pub index: i32,
}
/// The function targeted by a [`ToolCall`].
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct Function {
    /// Function name.
    #[builder(setter(into))]
    pub name: String,
    /// Arguments kept as a raw string (NOTE(review): likely JSON-encoded — confirm).
    #[builder(setter(into))]
    pub arguments: String,
}
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq)]
pub struct ToolMessage {
#[builder(setter(into), default = "\"tool\".to_string()")]
pub role: String,
#[builder(setter(into))]
pub content: String,
#[builder(setter(into))]
pub tool_call_id: Option<String>,
}
impl From<ToolMessage> for Message {
fn from(value: ToolMessage) -> Self {
Self::Tool(value)
}
}
impl RequestTrait for GenerationParam {
    type P = Parameters;

    /// Borrows the model identifier.
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Borrows the generation parameters, if any were supplied.
    fn parameters(&self) -> Option<&Self::P> {
        Option::as_ref(&self.parameters)
    }
}