use std::collections::HashMap;
use openai_client_base::models::create_message_request::Role as MessageRole;
use openai_client_base::models::{
assistant_tools_code, assistant_tools_file_search_type_only, AssistantToolsCode,
AssistantToolsFileSearchTypeOnly, CreateMessageRequest, CreateMessageRequestAttachmentsInner,
CreateMessageRequestAttachmentsInnerToolsInner, CreateThreadRequest,
};
use serde_json::Value;
use crate::{Builder, Result};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttachmentTool {
CodeInterpreter,
FileSearch,
}
impl AttachmentTool {
fn to_api(self) -> CreateMessageRequestAttachmentsInnerToolsInner {
match self {
Self::CodeInterpreter => {
CreateMessageRequestAttachmentsInnerToolsInner::AssistantToolsCode(Box::new(
AssistantToolsCode::new(assistant_tools_code::Type::CodeInterpreter),
))
}
Self::FileSearch => {
CreateMessageRequestAttachmentsInnerToolsInner::AssistantToolsFileSearchTypeOnly(
Box::new(AssistantToolsFileSearchTypeOnly::new(
assistant_tools_file_search_type_only::Type::FileSearch,
)),
)
}
}
}
}
/// A file attached to a thread message, together with the tools that should receive it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessageAttachment {
    /// Identifier of the attached file.
    file_id: String,
    /// Tools the file is exposed to; kept free of duplicates by [`MessageAttachment::with_tool`].
    tools: Vec<AttachmentTool>,
}

impl MessageAttachment {
    /// Creates an attachment that exposes `file_id` to the code interpreter tool.
    #[must_use]
    pub fn for_code_interpreter(file_id: impl Into<String>) -> Self {
        Self::single_tool(file_id.into(), AttachmentTool::CodeInterpreter)
    }

    /// Creates an attachment that exposes `file_id` to the file search tool.
    #[must_use]
    pub fn for_file_search(file_id: impl Into<String>) -> Self {
        Self::single_tool(file_id.into(), AttachmentTool::FileSearch)
    }

    /// Adds `tool` to this attachment. A tool that is already present is not added again.
    #[must_use]
    pub fn with_tool(mut self, tool: AttachmentTool) -> Self {
        let already_listed = self.tools.iter().any(|existing| *existing == tool);
        if !already_listed {
            self.tools.push(tool);
        }
        self
    }

    /// Shared constructor behind the single-tool convenience constructors.
    fn single_tool(file_id: String, tool: AttachmentTool) -> Self {
        Self {
            file_id,
            tools: vec![tool],
        }
    }

    /// Converts into the generated API type.
    ///
    /// `tools` is only assigned when at least one tool is listed. Otherwise it keeps
    /// whatever `CreateMessageRequestAttachmentsInner::new()` initialised it to.
    fn into_api(self) -> CreateMessageRequestAttachmentsInner {
        let Self { file_id, tools } = self;
        let mut api_attachment = CreateMessageRequestAttachmentsInner::new();
        api_attachment.file_id = Some(file_id);
        if !tools.is_empty() {
            api_attachment.tools = Some(tools.into_iter().map(AttachmentTool::to_api).collect());
        }
        api_attachment
    }
}
/// Tri-state record of the metadata a builder will attach to its request.
///
/// Keeps "never touched" ([`MetadataState::Unset`]) distinct from "explicitly cleared"
/// ([`MetadataState::ExplicitNull`]). The two map to different
/// [`MetadataState::into_option`] results: `None` and `Some(None)`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
enum MetadataState {
    /// No metadata call has been made.
    #[default]
    Unset,
    /// Key/value pairs collected so far.
    Present(HashMap<String, String>),
    /// Metadata was explicitly cleared.
    ExplicitNull,
}

impl MetadataState {
    /// Inserts or overwrites `key`.
    ///
    /// From `Unset` or `ExplicitNull`, this starts a fresh one-entry map.
    fn upsert(&mut self, key: String, value: String) {
        if let Self::Present(entries) = self {
            entries.insert(key, value);
            return;
        }
        *self = Self::Present(HashMap::from([(key, value)]));
    }

    /// Discards any previous state and uses `metadata` as the full set of pairs.
    fn replace(&mut self, metadata: HashMap<String, String>) {
        *self = Self::Present(metadata);
    }

    /// Marks metadata as explicitly cleared.
    fn clear(&mut self) {
        *self = Self::ExplicitNull;
    }

    /// Converts to the nested-option form assigned to request `metadata` fields.
    ///
    /// - `None`: unset, or a map with no entries.
    /// - `Some(None)`: explicitly cleared.
    /// - `Some(Some(map))`: a non-empty map.
    // The nested Option is deliberate: callers need to distinguish "omit the field"
    // (None) from "explicit null" (Some(None)).
    #[allow(clippy::option_option)]
    fn into_option(self) -> Option<Option<HashMap<String, String>>> {
        match self {
            Self::ExplicitNull => Some(None),
            Self::Present(entries) => (!entries.is_empty()).then_some(Some(entries)),
            Self::Unset => None,
        }
    }
}
/// Builder for a single [`CreateMessageRequest`] whose content is plain text.
#[derive(Debug, Clone, Default)]
pub struct ThreadMessageBuilder {
    /// Author role of the message.
    role: MessageRole,
    /// Text body of the message.
    content: String,
    /// Files attached to the message, in insertion order.
    attachments: Vec<MessageAttachment>,
    /// Message-level metadata.
    metadata: MetadataState,
}

impl ThreadMessageBuilder {
    /// Starts a user-role message with the given text.
    #[must_use]
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::User,
            content: content.into(),
            ..Self::default()
        }
    }

    /// Starts an assistant-role message with the given text.
    #[must_use]
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::Assistant,
            content: content.into(),
            ..Self::default()
        }
    }

    /// Replaces the text content.
    #[must_use]
    pub fn content(self, content: impl Into<String>) -> Self {
        Self {
            content: content.into(),
            ..self
        }
    }

    /// Appends one attachment.
    #[must_use]
    pub fn attachment(self, attachment: MessageAttachment) -> Self {
        self.attachments(std::iter::once(attachment))
    }

    /// Appends every attachment yielded by `attachments`, preserving order.
    #[must_use]
    pub fn attachments<I>(mut self, attachments: I) -> Self
    where
        I: IntoIterator<Item = MessageAttachment>,
    {
        self.attachments.extend(attachments);
        self
    }

    /// Sets a single metadata entry, overwriting any previous value for `key`.
    ///
    /// If metadata was previously cleared, this starts a fresh map.
    #[must_use]
    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.upsert(key, value);
        self
    }

    /// Replaces all metadata with `metadata`.
    ///
    /// An empty map results in no metadata being set on the built request.
    #[must_use]
    pub fn metadata_map(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata.replace(metadata);
        self
    }

    /// Marks metadata as explicitly cleared, so the built request carries `Some(None)`.
    #[must_use]
    pub fn clear_metadata(mut self) -> Self {
        self.metadata.clear();
        self
    }
}
impl Builder<CreateMessageRequest> for ThreadMessageBuilder {
    /// Assembles the message request.
    ///
    /// The content is wrapped as a JSON string. `attachments` is only assigned when at
    /// least one attachment was added. `metadata` comes from the builder's metadata state.
    /// This implementation never returns an error.
    fn build(self) -> Result<CreateMessageRequest> {
        let Self {
            role,
            content,
            attachments,
            metadata,
        } = self;

        let mut request = CreateMessageRequest::new(role, Value::String(content));
        if !attachments.is_empty() {
            let api_attachments = attachments
                .into_iter()
                .map(MessageAttachment::into_api)
                .collect();
            request.attachments = Some(Some(api_attachments));
        }
        request.metadata = metadata.into_option();

        Ok(request)
    }
}
impl ThreadMessageBuilder {
#[must_use]
pub fn finish(self) -> CreateMessageRequest {
self.build()
.expect("thread message builder should be infallible")
}
}
#[derive(Debug, Clone, Default)]
pub struct ThreadRequestBuilder {
messages: Vec<CreateMessageRequest>,
metadata: MetadataState,
}
impl ThreadRequestBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn user_message(mut self, content: impl Into<String>) -> Self {
self.messages
.push(ThreadMessageBuilder::user(content).finish());
self
}
#[must_use]
pub fn assistant_message(mut self, content: impl Into<String>) -> Self {
self.messages
.push(ThreadMessageBuilder::assistant(content).finish());
self
}
#[must_use]
pub fn message_request(mut self, message: CreateMessageRequest) -> Self {
self.messages.push(message);
self
}
pub fn message_builder(mut self, builder: ThreadMessageBuilder) -> Result<Self> {
self.messages.push(builder.build()?);
Ok(self)
}
#[must_use]
pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.upsert(key.into(), value.into());
self
}
#[must_use]
pub fn metadata_map(mut self, metadata: HashMap<String, String>) -> Self {
self.metadata.replace(metadata);
self
}
#[must_use]
pub fn clear_metadata(mut self) -> Self {
self.metadata.clear();
self
}
#[must_use]
pub fn messages(&self) -> &[CreateMessageRequest] {
&self.messages
}
}
impl Builder<CreateThreadRequest> for ThreadRequestBuilder {
    /// Assembles the thread request.
    ///
    /// `messages` is only assigned when at least one message was queued. `metadata`
    /// comes from the builder's metadata state. This implementation never returns an error.
    fn build(self) -> Result<CreateThreadRequest> {
        let Self { messages, metadata } = self;

        let mut request = CreateThreadRequest::new();
        if !messages.is_empty() {
            request.messages = Some(messages);
        }
        request.metadata = metadata.into_option();

        Ok(request)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_basic_user_message() {
        let builder = ThreadMessageBuilder::user("Hello");
        let message = builder.build().expect("builder should succeed");
        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, Value::String("Hello".to_string()));
        assert!(message.attachments.is_none());
        assert!(message.metadata.is_none());
    }

    #[test]
    fn builds_message_with_attachment() {
        let attachment = MessageAttachment::for_code_interpreter("file-123");
        let message = ThreadMessageBuilder::user("process this")
            .attachment(attachment)
            .build()
            .expect("builder should succeed");
        let attachments = message.attachments.unwrap().unwrap();
        assert_eq!(attachments.len(), 1);
        assert_eq!(attachments[0].file_id.as_deref(), Some("file-123"));
        // Pin the tool count, not just presence: exactly one tool was requested.
        let tools = attachments[0]
            .tools
            .as_ref()
            .expect("tools should be set for a non-empty tool list");
        assert_eq!(tools.len(), 1);
    }

    #[test]
    fn with_tool_does_not_duplicate_tools() {
        let attachment = MessageAttachment::for_file_search("file-9")
            .with_tool(AttachmentTool::FileSearch)
            .with_tool(AttachmentTool::CodeInterpreter)
            .with_tool(AttachmentTool::CodeInterpreter);
        assert_eq!(
            attachment.tools,
            vec![AttachmentTool::FileSearch, AttachmentTool::CodeInterpreter]
        );
    }

    #[test]
    fn builds_thread_with_metadata() {
        let thread = ThreadRequestBuilder::new()
            .user_message("Hi there")
            .metadata("topic", "support")
            .build()
            .expect("builder should succeed");
        let messages = thread.messages.expect("messages should be set");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::User);
        assert_eq!(messages[0].content, Value::String("Hi there".to_string()));
        let metadata = thread.metadata.unwrap().unwrap();
        assert_eq!(metadata.get("topic"), Some(&"support".to_string()));
    }

    #[test]
    fn can_explicitly_clear_metadata() {
        let thread = ThreadRequestBuilder::new()
            .metadata("foo", "bar")
            .clear_metadata()
            .build()
            .expect("builder should succeed");
        assert!(thread.metadata.is_some());
        assert!(thread.metadata.unwrap().is_none());
    }

    #[test]
    fn metadata_after_clear_starts_fresh_map() {
        let thread = ThreadRequestBuilder::new()
            .metadata("stale", "x")
            .clear_metadata()
            .metadata("a", "1")
            .build()
            .expect("builder should succeed");
        let metadata = thread.metadata.unwrap().unwrap();
        assert_eq!(metadata.len(), 1);
        assert_eq!(metadata.get("a"), Some(&"1".to_string()));
    }

    #[test]
    fn empty_metadata_map_is_omitted() {
        let thread = ThreadRequestBuilder::new()
            .metadata_map(HashMap::new())
            .build()
            .expect("builder should succeed");
        assert!(thread.metadata.is_none());
    }

    #[test]
    fn accepts_custom_message_builder() {
        let message_builder = ThreadMessageBuilder::assistant("Hello").metadata("tone", "friendly");
        let thread = ThreadRequestBuilder::new()
            .message_builder(message_builder)
            .expect("builder should succeed")
            .build()
            .expect("thread build should succeed");
        let message = thread.messages.unwrap();
        assert_eq!(message.len(), 1);
        assert_eq!(message[0].role, MessageRole::Assistant);
        let metadata = message[0].metadata.clone().unwrap().unwrap();
        assert_eq!(metadata.get("tone"), Some(&"friendly".to_string()));
    }
}